In [None]:
!pip install -Uq wandb tqdm torchsummary

In [1]:
import wandb
# W&B account, project and dataset artifact used throughout this notebook.
wandb_username = "adrishd"
wandb_project = "taco-baseline"
dataset_artifact = 'adrishd/taco/taco:pytorch'

# Fetch the dataset artifact. artifact.download() uses wandb's default location
# under the current working directory unless root=<custom_path> is passed.
with wandb.init(entity=wandb_username, project=wandb_project) as download_run:
    artifact = download_run.use_artifact(dataset_artifact, type='dataset')
    artifact_dir = artifact.download()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mwandb_fc[0m (use `wandb login --relogin` to force relogin)
  warn("The `IPython.html` package has been deprecated since IPython 4.0. "


[34m[1mwandb[0m: Downloading large artifact taco:pytorch, 2507.15MB. 1503 files... Done. 0:0:0


In [2]:
import tacoloader
import torch
import torch.nn
import torch.nn.functional as F
import torchsummary
import time
import tqdm
import torchvision

In [3]:
h, w, c = 512, 512, 3  # target image height, width and channel count
# Resize every sample to (h, w) with nearest-neighbour interpolation, which never
# invents new pixel values. NOTE(review): presumably chosen so segmentation-mask
# labels stay valid if tacoloader also applies transform_fn to masks; confirm.
# To chain several transforms, wrap them in torchvision.transforms.Compose:
# https://pytorch.org/vision/stable/transforms.html
transform = torchvision.transforms.Resize(
    size=(h, w),
    interpolation=torchvision.transforms.InterpolationMode.NEAREST,
)

In [4]:
# Training-pipeline constants.
train_batch_size, test_batch_size = 10, 1  # images per batch for training / evaluation
split = 0.8  # fraction of the dataset assigned to the training split

In [5]:
dataset, collate_fn = tacoloader.load_dataset(artifact_dir, tacoloader.Environment.TORCH, transform_fn=transform)
# Contiguous split in dataset order (no shuffling): the first `split` fraction
# goes to training, the remainder to testing.
dataset_size = len(dataset)
split_index = int(split * dataset_size)
indices = range(dataset_size)
# The test slice must start exactly where the train slice ends; starting at
# split_index + 1 would silently drop one sample from both splits.
train_indices = indices[:split_index]
test_indices = indices[split_index:]
train_dataset = torch.utils.data.Subset(dataset, train_indices)
test_dataset = torch.utils.data.Subset(dataset, test_indices)
assert len(train_dataset) + len(test_dataset) == dataset_size

NOTE! Installing ujson may make loading annotations faster.
creating index...
index created!


In [6]:
# Data loaders. Training batches are reshuffled every epoch; the test loader
# keeps a fixed order so evaluation (and any per-sample logging) is reproducible.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    collate_fn=dataset.collate_fn,
    num_workers=6,
    shuffle=True)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=test_batch_size,
    collate_fn=dataset.collate_fn,
    num_workers=6,
    shuffle=False)

In [7]:
def viz_mask(image, pred_mask, true_mask):
    """Log an image with predicted and ground-truth segmentation overlays to W&B.

    Args:
        image: torch tensor of dim [c, h, w].
        pred_mask: detached torch tensor of class ids, dim [h, w]
            (singleton dims are squeezed away).
        true_mask: torch tensor of class ids, dim [h, w]
            (singleton dims are squeezed away).

    Side effects:
        Calls ``wandb.log`` under the key ``"semantic_segmentation"``.
    """
    # W&B mask_data must hold integer class ids; cast in case masks arrive as float.
    pred_ids = pred_mask.squeeze().long().cpu()
    true_ids = true_mask.squeeze().long().cpu()

    def _class_labels(mask_ids):
        # W&B looks up every mask pixel value in class_labels, so the keys must be
        # the class ids present in the mask, not their position in a list.
        # NOTE(review): assumes dataset.get_categories returns names in the same
        # order as the ids passed in; confirm against tacoloader.
        ids = torch.unique(mask_ids).numpy().tolist()
        return dict(zip(ids, dataset.get_categories(ids)))

    wandb_image = wandb.Image(image.cpu(), masks={
        "prediction": {
            "mask_data": pred_ids.numpy(),
            "class_labels": _class_labels(pred_ids)
        },
        "ground_truth": {
            "mask_data": true_ids.numpy(),
            "class_labels": _class_labels(true_ids)
        }
    })
    wandb.log({"semantic_segmentation" : wandb_image})

## Model Design and Implementation
### Starter Code: Helper Modules for UNet Image Segmentation

In [8]:
# Helper for looking up an activation class in torch.nn by name.
# Activations that accept an `inplace` argument are created with
# inplace=True by default.
import inspect
import functools
def get_activation_fn(fn_name):
    """Return a factory for the torch.nn activation class called `fn_name`.

    If the class constructor accepts an `inplace` argument, the returned
    factory is that class with inplace=True pre-bound; otherwise the class
    itself is returned unchanged.
    """
    activation_cls = getattr(torch.nn, fn_name)
    accepted_params = inspect.signature(activation_cls).parameters
    if "inplace" not in accepted_params:
        return activation_cls
    return functools.partial(activation_cls, inplace=True)

In [9]:
# Dummy baseline UNet model based on:
# https://github.com/xiaopeng-liao/Pytorch-UNet/blob/master/unet/unet_parts.py
class double_conv(torch.nn.Module):
    '''Two rounds of (3x3 conv => BatchNorm2d => activation); spatial size is kept by padding=1.'''
    def __init__(self, in_ch, out_ch, activation_fn_name):
        super().__init__()
        make_activation = get_activation_fn(activation_fn_name)
        layers = []
        # First conv maps in_ch -> out_ch, the second out_ch -> out_ch.
        for src_ch in (in_ch, out_ch):
            layers += [
                torch.nn.Conv2d(src_ch, out_ch, 3, padding=1),
                torch.nn.BatchNorm2d(out_ch),
                make_activation(),
            ]
        self.conv = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


class inconv(torch.nn.Module):
    '''UNet input stem: a single double_conv applied at full resolution.'''
    def __init__(self, in_ch, out_ch, activation_fn):
        super().__init__()
        self.conv = double_conv(in_ch, out_ch, activation_fn)

    def forward(self, x):
        return self.conv(x)


class down(torch.nn.Module):
    '''Encoder step: 2x2 max-pool (halving height and width), then double_conv.'''
    def __init__(self, in_ch, out_ch, activation_fn):
        super().__init__()
        pool = torch.nn.MaxPool2d(2)
        convs = double_conv(in_ch, out_ch, activation_fn)
        self.mpconv = torch.nn.Sequential(pool, convs)

    def forward(self, x):
        return self.mpconv(x)


class up(torch.nn.Module):
    '''Decoder step: upsample x1 by 2, pad it to x2's spatial size, concatenate
    it with the skip connection x2 along the channel dim, then apply double_conv.

    Args:
        in_ch: channels after concatenation (x1 channels + x2 channels).
        out_ch: output channels of the double_conv.
        activation_fn: torch.nn activation name forwarded to double_conv.
        bilinear: upsample with bilinear interpolation if True, otherwise with
            a learned ConvTranspose2d.
    '''
    def __init__(self, in_ch, out_ch, activation_fn, bilinear=True):
        super(up, self).__init__()
        if bilinear:
            self.up = torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            # Assumes x1 carries in_ch // 2 channels (the other half comes from x2).
            self.up = torch.nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)

        self.conv = double_conv(in_ch, out_ch, activation_fn)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Tensors are NCHW: dim 2 is height, dim 3 is width.
        diff_h = x2.size(2) - x1.size(2)
        diff_w = x2.size(3) - x1.size(3)
        # F.pad lists padding for the LAST dim first: (left, right, top, bottom).
        # Pairing the height difference with left/right would mis-pad any
        # non-square feature map and make torch.cat fail.
        x1 = F.pad(x1, (diff_w // 2, diff_w - diff_w // 2,
                        diff_h // 2, diff_h - diff_h // 2))
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class outconv(torch.nn.Module):
    '''Final 1x1 convolution mapping features to per-class scores.

    By default the output is raw logits, so it can be fed directly to
    torch.nn.CrossEntropyLoss. That loss applies log-softmax internally, so a
    softmax here would be applied twice and would squash the gradients.
    argmax over dim 1 gives the same class either way because softmax is
    monotonic.

    Args:
        in_ch: number of input feature channels.
        out_ch: number of output classes.
        apply_softmax: if True, return per-pixel class probabilities
            (softmax over dim 1) instead of logits, e.g. for inference.
    '''
    def __init__(self, in_ch, out_ch, apply_softmax=False):
        super(outconv, self).__init__()
        self.conv = torch.nn.Conv2d(in_ch, out_ch, 1)
        self.softmax = torch.nn.Softmax(dim=1) if apply_softmax else torch.nn.Identity()

    def forward(self, x):
        return self.softmax(self.conv(x))

In [10]:
class UNet(torch.nn.Module):
    '''Four-level UNet built from inconv/down/up/outconv blocks.

    Args:
        n_channels: number of channels in the input image.
        n_classes: number of output channels (one per class).
        config: mapping with keys "unet_channels" (base width),
            "activation_fn" (torch.nn activation name) and "bilinear"
            (upsampling mode passed to every up block).
    '''
    def __init__(self, n_channels, n_classes, config):
        super().__init__()
        base = config["unet_channels"]
        act = config["activation_fn"]
        bilinear = config["bilinear"]
        # Encoder: width doubles per level; the bottleneck keeps 8 * base.
        self.inc = inconv(n_channels, base, act)
        self.down1 = down(base, 2 * base, act)
        self.down2 = down(2 * base, 4 * base, act)
        self.down3 = down(4 * base, 8 * base, act)
        self.down4 = down(8 * base, 8 * base, act)
        # Decoder: each up block sees upsampled features concatenated with a skip.
        self.up1 = up(16 * base, 4 * base, act, bilinear=bilinear)
        self.up2 = up(8 * base, 2 * base, act, bilinear=bilinear)
        self.up3 = up(4 * base, base, act, bilinear=bilinear)
        self.up4 = up(2 * base, base, act, bilinear=bilinear)
        self.outc = outconv(base, n_classes)

    def forward(self, x):
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        out = self.down4(skip4)
        decoder = (
            (self.up1, skip4),
            (self.up2, skip3),
            (self.up3, skip2),
            (self.up4, skip1),
        )
        for up_block, skip in decoder:
            out = up_block(out, skip)
        return self.outc(out)

In [11]:
# Per-pixel multi-class loss, averaged over all pixels in the batch.
# CrossEntropyLoss applies log-softmax itself, so it expects raw logits.
loss_fn = torch.nn.CrossEntropyLoss(reduction="mean")

### Training, Logging, and Hyperparameter Search

In [12]:
# Hyperparameter search with W&B sweeps.
# More information can be found here: https://docs.wandb.ai/guides/sweeps
sweep_config = {
    "name" : "sweep-params",
    "method" : "bayes",
    "metric" : {
        # Pixel accuracy: higher is better, so the sweep must maximize it.
        "name": "pixacc",
        "goal": "maximize"
    },
    "parameters" : {
        "epochs" : {
            "values": [1, 3, 5]
        },
        "lr" : {
            "min": 1e-4,
            "max": 1e-2
        },
        "activation_fn" : {
            "values" : ["ReLU", "LeakyReLU", "PReLU"]
        },
        "unet_channels" : {
            "values" : [8, 16, 32]
        },
        "bilinear" : {
            "values" : [True, False]
        },
    }
}
sweep_id = wandb.sweep(
    sweep_config,
    entity=wandb_username,
    project=wandb_project
)

Create sweep with ID: cmd01r6y
Sweep URL: https://wandb.ai/adrishd/taco-baseline/sweeps/cmd01r6y


In [13]:
def pixel_accuracy(pred, ground):
    """Fraction of pixels where `pred` equals `ground`.

    Both tensors are moved to CPU first, so either may live on the GPU.
    Their shapes must be broadcast-compatible; dtypes may differ (e.g. a long
    prediction against a float mask).

    Returns:
        float in [0, 1].

    Raises:
        ValueError: if the compared tensors contain no elements.
    """
    matches = torch.eq(pred.cpu(), ground.cpu())
    if matches.numel() == 0:
        raise ValueError("pixel_accuracy received empty tensors")
    return float(matches.sum()) / matches.numel()

In [16]:
def train():
    """Run one W&B sweep trial.

    The trial:
    - builds a UNet from the hyperparameters in ``wandb.config``;
    - trains it on ``train_loader`` with Adam for ``config["epochs"]`` epochs,
      logging per-batch ``"loss"``;
    - evaluates mean pixel accuracy against the ground-truth masks of
      ``test_loader`` and logs it once as ``"pixacc"`` (the sweep metric);
    - logs a visualization of the first test sample's predicted mask.
    """
    with wandb.init(entity=wandb_username, project=wandb_project):
        config = wandb.config
        unet = UNet(c, dataset.len_categories, config).cuda()
        optim = torch.optim.Adam(unet.parameters(), lr=config["lr"])

        # Training mode: BatchNorm uses per-batch statistics.
        unet.train()
        for epoch in range(config["epochs"]):
            # A tqdm wrapper is closed after one full pass, so make one per epoch.
            bar = tqdm.tqdm(train_loader, leave=False)
            for data in bar:
                optim.zero_grad()
                segmask = unet(data.images.cuda())
                loss = loss_fn(segmask, data.masks.cuda().long())
                loss.backward()
                optim.step()
                loss_value = loss.item()
                bar.set_description("Loss: %f" % loss_value)
                wandb.log({"loss": loss_value})

        # Evaluation: eval() switches BatchNorm to its running statistics and
        # no_grad() avoids building autograd graphs.
        unet.eval()
        with torch.no_grad():
            accuracies = []
            for data in tqdm.tqdm(test_loader, leave=False):
                segmask = unet(data.images.cuda())
                pred_mask = torch.argmax(segmask, dim=1).squeeze()
                # Compare with the ground-truth masks, not with the network output.
                accuracies.append(pixel_accuracy(pred_mask, data.masks.squeeze()))
            if accuracies:
                # Log a single summary value so the sweep scores the whole test
                # split rather than whichever sample happened to be logged last.
                wandb.log({"pixacc": sum(accuracies) / len(accuracies)})

            # Draw one sample and visualize the mask for each sweep run.
            if len(test_dataset) > 0:
                sample = test_dataset[0]
                segmask = unet(sample.image.unsqueeze(0).cuda())
                pred_mask = torch.argmax(segmask, dim=1).squeeze()
                viz_mask(sample.image, pred_mask, sample.mask)

In [17]:
# Launch the sweep agent: it calls train() once per run, each with a new config.
count = 5  # number of sweep runs to execute
wandb.agent(sweep_id, function=train, count=count)

[34m[1mwandb[0m: Agent Starting Run: l5hpwmjo with config:
[34m[1mwandb[0m: 	activation_fn: ReLU
[34m[1mwandb[0m: 	bilinear: False
[34m[1mwandb[0m: 	epochs: 5
[34m[1mwandb[0m: 	lr: 0.008043417664695626
[34m[1mwandb[0m: 	unet_channels: 32
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Loss: 3.820595:  10%|█         | 12/120 [00:11<00:59,  1.81it/s][34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.
                                                                

0,1
loss,██▇▇▆▅▅▄▄▃▂▁

0,1
loss,3.8206
