In [None]:
!pip install -Uq wandb tqdm torchsummary

In [1]:
import getpass
import os

import wandb

# Prefer a key that is already present in the environment; otherwise prompt for
# it without echoing the secret to the notebook output.
key = os.environ.get("WANDB_API_KEY", None)
if not key:
    key = getpass.getpass("API Key: ")
    # `!export ...` only affects a short-lived subshell, so the variable has to
    # be set on this process for the kernel (and its children) to see it.
    os.environ["WANDB_API_KEY"] = key
# Log in through the Python API so the secret is never interpolated into a
# shell command line.
wandb.login(key=key)

API Key:  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
import wandb
# Start a W&B run in the "taco-baseline" project under the "adrishd" entity.
run = wandb.init(entity="adrishd", project="taco-baseline")
# Declare the dataset artifact as an input of this run and fetch its files;
# `artifact_dir` is handed to the dataset loader in a later cell.
artifact = run.use_artifact('adrishd/taco/taco:pytorch', type='dataset')
artifact_dir = artifact.download()

In [None]:
import tacoloader
import torch
import torch.nn
import torch.nn.functional as F
import torchsummary
import time
import tqdm
import torchvision

In [None]:
# Target height, width and channel count: (h, w) is the resize target below and
# (c, h, w) is the input shape given to torchsummary further down.
h, w, c = 512, 512, 3
# Nearest-neighbour resize to (h, w).
# NOTE(review): presumably nearest was chosen so that, if the loader applies this
# transform to masks as well, class ids are not blended — confirm in tacoloader.
transform = torchvision.transforms.Resize(
    (h, w),
    torchvision.transforms.InterpolationMode.NEAREST)

In [None]:
# Load the dataset (and its matching collate function) from the downloaded artifact.
dataset, collate_fn = tacoloader.load_dataset(artifact_dir, tacoloader.Environment.TORCH, transform_fn=transform)

# Sequential train/test split: the first `split` fraction is used for training
# and everything after it for testing.
split = 0.8
dataset_size = len(dataset)
split_index = int(split * dataset_size)
indices = range(dataset_size)
train_indices = indices[:split_index]
# Start exactly at split_index; starting at split_index + 1 would silently drop
# one sample from both subsets.
test_indices = indices[split_index:]
assert len(train_indices) + len(test_indices) == dataset_size

train_dataset = torch.utils.data.Subset(dataset, train_indices)
test_dataset = torch.utils.data.Subset(dataset, test_indices)

In [None]:
# Batch with the collate function that load_dataset() returned together with
# the dataset, instead of relying on a `collate_fn` attribute of the dataset
# object that nothing in this notebook establishes.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=10,
    collate_fn=collate_fn,
    num_workers=6,
    shuffle=True)
# Test samples are served one at a time.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,
    collate_fn=collate_fn,
    num_workers=6,
    shuffle=True)

In [None]:
def viz_mask(pred_mask, true_mask, image=None):
    """Log a predicted and a ground-truth segmentation mask to W&B as overlays.

    Args:
        pred_mask: tensor of predicted class ids for ONE sample. Singleton
            dimensions are squeezed away; the result must be 2-D (H, W).
        true_mask: tensor of ground-truth class ids for the same sample, handled
            the same way as ``pred_mask``.
        image: optional image to draw the masks on (anything ``wandb.Image``
            accepts). When omitted, a black canvas of the mask's size is used.

    Side effects:
        Calls ``wandb.log`` with the key ``"semantic_segmentation"``.
    """
    def _overlay(mask):
        # W&B expects a 2-D integer array on the host, not a (CUDA) tensor.
        mask = mask.detach().squeeze().long().cpu()
        class_ids = [int(i) for i in torch.unique(mask).tolist()]
        # Key the label map by the class id stored in the mask, not by the
        # position of that id within the list of unique values.
        # NOTE(review): assumes get_categories returns one name per id, in the
        # order given — confirm against tacoloader.
        names = dataset.get_categories(class_ids)
        return {
            "mask_data": mask.numpy(),
            "class_labels": dict(zip(class_ids, names)),
        }

    prediction = _overlay(pred_mask)
    ground_truth = _overlay(true_mask)

    if image is None:
        # No image supplied: draw the masks on a black (3, H, W) canvas.
        image = torch.zeros(3, *prediction["mask_data"].shape)

    wandb_image = wandb.Image(image, masks={
        "prediction": prediction,
        "ground_truth": ground_truth,
    })
    wandb.log({"semantic_segmentation": wandb_image})

## Model Design and Implementations
### Starter Code: Helper Modules for UNet Image Segmentation

In [None]:
# https://github.com/xiaopeng-liao/Pytorch-UNet/blob/master/unet/unet_parts.py
class double_conv(torch.nn.Module):
    """Two stacked 3x3 convolutions, each followed by batch norm and ReLU."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        # First stage maps in_ch -> out_ch, second stage keeps out_ch.
        for stage_in in (in_ch, out_ch):
            layers.append(torch.nn.Conv2d(stage_in, out_ch, 3, padding=1))
            layers.append(torch.nn.BatchNorm2d(out_ch))
            layers.append(torch.nn.ReLU(inplace=True))
        self.conv = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


class inconv(torch.nn.Module):
    """Input stage of the network: a single double_conv block."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        return self.conv(x)


class down(torch.nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a double_conv block."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        pool = torch.nn.MaxPool2d(2)
        convs = double_conv(in_ch, out_ch)
        self.mpconv = torch.nn.Sequential(pool, convs)

    def forward(self, x):
        return self.mpconv(x)


class up(torch.nn.Module):
    """Decoder stage: upsample, pad to the skip tensor's size, concatenate, convolve.

    Args:
        in_ch: channel count of the concatenated tensor fed to the double_conv
            (skip connection channels + upsampled channels).
        out_ch: channel count produced by this stage.
        bilinear: if True, upsample with fixed bilinear interpolation; otherwise
            use a learned ConvTranspose2d over ``in_ch // 2`` channels.
    """

    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()

        # Learned upsampling would be nice to have, but it adds weights and
        # memory use, so fixed bilinear interpolation is the default.
        if bilinear:
            self.up = torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = torch.nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)

        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it with skip tensor ``x2`` and fuse the two.

        Both tensors are indexed as 4-D (dim 2 = height, dim 3 = width, as in
        NCHW); the result is ``double_conv(cat([x2, x1], dim=1))``.
        """
        x1 = self.up(x1)
        diff_h = x2.size(2) - x1.size(2)
        diff_w = x2.size(3) - x1.size(3)
        # F.pad consumes padding for the LAST dimension first, i.e.
        # (left, right, top, bottom): width padding must precede height padding.
        x1 = F.pad(x1, (diff_w // 2, diff_w - diff_w // 2,
                        diff_h // 2, diff_h - diff_h // 2))
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class outconv(torch.nn.Module):
    """Output head: 1x1 convolution followed by a softmax over dim 1.

    Because of the softmax, the values along dim 1 of the output sum to one
    (they are probabilities, not raw logits).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_ch, out_ch, 1)
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x):
        return self.softmax(self.conv(x))

In [None]:
class UNet(torch.nn.Module):
    """UNet built from the helper modules: an input stage, four down stages,
    four up stages fed by skip connections, and an output head.

    Args:
        n_channels: channel count of the input passed to the first stage.
        n_classes: channel count produced by the output head.
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()
        # Encoder
        self.inc = inconv(n_channels, 16)
        self.down1 = down(16, 32)
        self.down2 = down(32, 64)
        self.down3 = down(64, 128)
        self.down4 = down(128, 128)
        # Decoder (learned transposed-convolution upsampling)
        self.up1 = up(256, 64, bilinear=False)
        self.up2 = up(128, 32, bilinear=False)
        self.up3 = up(64, 16, bilinear=False)
        self.up4 = up(32, 16, bilinear=False)
        # Head
        self.outc = outconv(16, n_classes)

    def forward(self, x):
        # Run the encoder, remembering every intermediate for the skip links.
        skips = [self.inc(x)]
        for encoder_stage in (self.down1, self.down2, self.down3, self.down4):
            skips.append(encoder_stage(skips[-1]))

        # The deepest feature map seeds the decoder; each up stage then consumes
        # the next-shallower skip tensor.
        x = skips.pop()
        for decoder_stage in (self.up1, self.up2, self.up3, self.up4):
            x = decoder_stage(x, skips.pop())
        return self.outc(x)

In [None]:
# Build the network on the GPU. The input channel count comes from the same `c`
# constant that describes the input shape in the summary call below, so the two
# cannot drift apart.
unet = UNet(c, dataset.len_categories).cuda()
torchsummary.summary(unet, (c, h, w))

In [None]:
def loss_fn(probs, target, eps=1e-8):
    """Cross-entropy for a network whose output is already softmax-normalised.

    torch.nn.CrossEntropyLoss expects raw logits and applies log-softmax itself.
    The UNet's output head already applies a softmax, so feeding its output to
    CrossEntropyLoss would normalise twice and flatten the gradients. Taking the
    negative log-likelihood of the log-probabilities gives the intended loss.

    Args:
        probs: per-class probabilities with the class axis at dim 1.
        target: integer class indices, one per prediction (no class axis).
        eps: lower clamp applied before the log so zero probabilities do not
            produce infinities.

    Returns:
        Scalar tensor: mean negative log-likelihood.
    """
    return F.nll_loss(torch.log(probs.clamp_min(eps)), target)
# Adam over every UNet parameter; no learning rate or other hyperparameters are
# passed, so the optimizer's defaults apply.
optim = torch.optim.Adam(unet.parameters())

### Training and Logging Loop

In [None]:
# One pass over the training set with a progress bar.
bar = tqdm.tqdm(train_loader)
for data in bar:
    optim.zero_grad()
    segmask = unet(data.images.cuda())
    loss = loss_fn(segmask, data.masks.cuda().long())
    loss.backward()
    optim.step()

    # Visualise the first sample of the current batch. The ground truth comes
    # from the batch itself (there is no `sample` object in this notebook).
    mask = torch.argmax(segmask[0:1], dim=1).detach()
    viz_mask(mask, data.masks[0])

    # .item() yields a plain Python float for the progress bar and the logger.
    loss_value = loss.item()
    bar.set_description("Loss: %f" % loss_value)
    wandb.log({"loss": loss_value})

In [None]:
# Close the W&B run that was opened with wandb.init() at the top of the notebook.
wandb.finish()