In [None]:
# Automatically re-import edited project modules (e.g. coco_stuff, unet) before each cell
# executes, so changes to their .py files take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from coco_stuff import COCOStuff
from unet import UNet
from pathlib import Path
import matplotlib.pyplot as plt
import torch
from torchvision.transforms import Normalize, ToTensor, Resize, Lambda
import torch.nn.functional as F
import numpy as np
from copy import deepcopy
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim

In [None]:
# Seed every RNG this notebook relies on (weight init, DataLoader shuffling, numpy) so that
# Restart & Run All reproduces the same training run.
SEED = 0
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators
np.random.seed(SEED)

# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

In [None]:
# Dataset layout, relative to the notebook's working directory:
#   data/images/train2017, data/images/val2017                         -> image folders per split
#   data/annotations/stuff_train2017.json, .../stuff_val2017.json      -> annotation files per split
DATA = Path("data")

IMG = DATA.joinpath("images")
IMG_TRAIN, IMG_VAL = IMG.joinpath("train2017"), IMG.joinpath("val2017")

ANNOT = DATA.joinpath("annotations")
ANNOT_TRAIN, ANNOT_VAL = ANNOT.joinpath("stuff_train2017.json"), ANNOT.joinpath("stuff_val2017.json")

In [None]:
# Network input resolution as (height, width) -- the `size` argument of F.interpolate.
# (512, 512) is the higher-resolution alternative.
INPUT_SIZE = (256, 256)

In [None]:
# Image pipeline: ToTensor (scales uint8/PIL input to [0, 1]) -> per-channel (x - 0.5) / 0.5
# -> bilinear resize to INPUT_SIZE.
image_transforms = [
    ToTensor(),
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    Lambda(
        # F.interpolate needs a batch dim. squeeze(0) removes only that dim, so an image whose
        # channel axis has size 1 keeps it (a bare squeeze() would silently drop it).
        lambda img: F.interpolate(
            img.unsqueeze(0), INPUT_SIZE, mode="bilinear", align_corners=True
        ).squeeze(0)
    ),
]

# Mask pipeline: (H, W) label map -> (1, H, W) float tensor -> nearest-neighbour resize
# (never blends label ids) -> (H', W') int64 class indices, the target layout
# nn.CrossEntropyLoss expects.
mask_transforms = [
    # float32 because F.interpolate's integer-dtype support is limited; float32 represents
    # small integer label ids exactly.
    Lambda(lambda mask: torch.as_tensor(np.asarray(mask), dtype=torch.float32).unsqueeze(0)),
    Lambda(
        lambda mask: F.interpolate(
            mask.unsqueeze(0), INPUT_SIZE, mode="nearest"
        ).squeeze(0)
    ),
    Lambda(lambda mask: mask.squeeze(0).long()),
]

In [None]:
# Build both splits with identical image/mask transforms. ds_train is required further down
# (loader_train and UNet(n_classes=ds_train.n_classes)), so it must be constructed for the
# notebook to survive Restart & Run All.
ds_train = COCOStuff(
    images_path=IMG_TRAIN,
    annotations_json=ANNOT_TRAIN,
    transformations=image_transforms,
    target_transformations=mask_transforms
)

ds_val = COCOStuff(
    images_path=IMG_VAL,
    annotations_json=ANNOT_VAL,
    transformations=image_transforms,
    target_transformations=mask_transforms
)


In [None]:
# Inspect the raw annotation records attached to val image id 139; show the first one.
sample_ann_ids = ds_val.coco.getAnnIds(139)
annotations = ds_val.coco.loadAnns(sample_ann_ids)
annotations[0]

In [None]:
# The model is sized from ds_train.n_classes but evaluated on ds_val, so both splits must
# report the same class count.
assert ds_train.n_classes == ds_val.n_classes, (ds_train.n_classes, ds_val.n_classes)
ds_train.n_classes, ds_val.n_classes

In [None]:
# Every category record known to the val annotation file.
val_cat_ids = ds_val.coco.getCatIds()
ds_val.coco.loadCats(val_cat_ids)

In [None]:
# Overlay each label present in one val sample on top of its image, one figure per label.
img, mask = ds_val.get_image_and_mask(ds_val._image_id(4145))

# Build the id -> name lookup once instead of reloading every category for each label.
cat_names = {
    cat["id"]: cat["name"]
    for cat in ds_val.coco.loadCats(ds_val.coco.getCatIds())
}

for label in np.unique(mask):
    # Labels without a category record (if any) are still shown, with a placeholder name.
    plt.title(f"{label}: {cat_names.get(int(label), '<no category>')}")
    plt.imshow(img)
    plt.imshow(mask == label, alpha=0.7)
    plt.show()

In [None]:
# Sorted label ids present in the sample mask loaded above.
labels_in_mask = np.unique(mask)
labels_in_mask

In [None]:
loader_train = DataLoader(ds_train, batch_size=4, shuffle=True)
loader_val = DataLoader(ds_val, batch_size=2)


def _cycle(loader):
    """Yield batches from `loader` forever.

    Each time a pass is exhausted a new one is started, so a shuffling loader reshuffles
    every pass (unlike itertools.cycle, which would replay cached batches). An empty loader
    ends the generator instead of spinning forever.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            return


# The training loop pulls batches with next(...); cycling avoids StopIteration once a
# split has been fully consumed.
iter_train = _cycle(loader_train)
iter_val = _cycle(loader_val)

In [None]:
# Untrained reference network. The optimizer cell deep-copies it, so re-running that cell
# always restarts training from these same initial weights.
_n_classes = ds_train.n_classes
model_base = UNet(n_classes=_n_classes)

In [None]:
# NOTE(review): 4e-2 is 40x Adam's documented default lr (1e-3); lower it if the loss diverges.
LEARNING_RATE = 4e-2

# Pixel-wise classification loss over the class dimension.
loss_fn = nn.CrossEntropyLoss().to(device)

# Fresh copy of the untouched template, so re-running this cell resets the weights.
model = deepcopy(model_base).to(device)

# Optional regularisation previously tried: weight_decay=10e-5 (i.e. 1e-4).
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
# Shape sanity check: one forward pass on a fresh training batch.
# eval() keeps any dropout/normalization layers from changing state and no_grad() skips
# building an autograd graph; the training loop switches back with model.train().
sample_img, sample_mask = next(iter(loader_train))
model.eval()
with torch.no_grad():
    sample_out = model(sample_img.to(device))
sample_out.size()

In [None]:
N_ITERS = 1          # outer rounds; each = STEPS_PER_ITER training steps + one validation pass
STEPS_PER_ITER = 1   # optimizer steps (mini-batches) per round -- not full epochs
N_VAL_BATCHES = 10   # validation mini-batches averaged per round

loss_hist = []
acc_hist = []
loss_val_hist = []
acc_val_hist = []


def _next_batch(iterator, loader):
    """Return (iterator, batch), restarting from `loader` if `iterator` is exhausted."""
    try:
        return iterator, next(iterator)
    except StopIteration:
        iterator = iter(loader)
        return iterator, next(iterator)


def _loss_and_accuracy(X, y):
    """Run the global `model` on one batch; return (loss tensor, pixel accuracy float)."""
    X, y = X.to(device), y.to(device)
    y_pred = model(X)              # class scores along dim 1, as CrossEntropyLoss requires
    preds = y_pred.argmax(dim=1)   # predicted class index per pixel
    loss = loss_fn(y_pred, y)
    accuracy = (preds == y).sum().item() / y.nelement()
    return loss, accuracy


for i in range(N_ITERS):
    # --- training steps ---
    model.train()
    losses, accs = [], []
    stepbar = tqdm(range(STEPS_PER_ITER))
    for _ in stepbar:
        iter_train, (X_train, y_train) = _next_batch(iter_train, loader_train)
        train_loss, train_accuracy = _loss_and_accuracy(X_train, y_train)
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
        losses.append(train_loss.item())
        accs.append(train_accuracy)
        stepbar.set_description(
            f"iter: {i} | "
            f"train_loss: {np.mean(losses)} | "
            f"train_acc: {np.mean(accs)}"
        )
    loss_hist.append(np.mean(losses))
    acc_hist.append(np.mean(accs))

    # --- validation: no_grad so evaluation batches build no autograd graph ---
    model.eval()
    losses, accs = [], []
    with torch.no_grad():
        for _ in range(N_VAL_BATCHES):
            iter_val, (X_val, y_val) = _next_batch(iter_val, loader_val)
            val_loss, val_accuracy = _loss_and_accuracy(X_val, y_val)
            losses.append(val_loss.item())
            accs.append(val_accuracy)
    loss_val_hist.append(np.mean(losses))
    acc_val_hist.append(np.mean(accs))
    print(
        f"val_loss: {loss_val_hist[-1]} | "
        f"val_acc: {acc_val_hist[-1]}"
    )


In [None]:
def _plot_history(title, train_values, val_values):
    """Plot one metric's train and val curves against the outer-iteration index."""
    plt.title(title)
    plt.plot(range(len(train_values)), train_values, label="train")
    plt.plot(range(len(val_values)), val_values, label="val")
    plt.xlabel("iteration")
    plt.legend()
    plt.show()


_plot_history("loss", loss_hist, loss_val_hist)
_plot_history("accuracy", acc_hist, acc_val_hist)

In [None]:
SAMPLE_IDX = 60  # arbitrary val example to visualise

img, mask = ds_val[SAMPLE_IDX]

# Inference only: eval() fixes dropout/normalization behaviour, no_grad() skips graph building.
model.eval()
with torch.no_grad():
    out = model(img.unsqueeze(0).to(device))
mask_pred = out.argmax(dim=1).cpu()

print("accuracy", (mask.numpy() == mask_pred.numpy()).mean())

# Undo Normalize(mean=0.5, std=0.5) so imshow receives values in [0, 1] instead of clipping
# the negative half of the range.
plt.title("input")
plt.imshow((img * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0))
plt.show()

plt.title("ground truth")
plt.imshow(mask)
plt.show()

plt.title("prediction")
plt.imshow(mask_pred.squeeze())
plt.show()
