In [1]:
import torch 
from torch import optim
from torch import nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import matplotlib.pyplot as plt
import os
import time

import sys

sys.path.append('..')
from datasets import COCOADatasetGraph, COCOADatasetDual

from Trainer import MaskTrainer, DualMaskTrainer
from UNet import UNet, UNetDual

  from .collection import imread_collection_wrapper


In [2]:
# Prefer the first CUDA device when one is visible; otherwise run on the CPU.
device_name = 'cuda:0' if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(device)

# Experiment switches, read as globals by the cells below.
DUAL = True             # selects the dual dataset / model / trainer variants
classification = True   # for graph
hingeloss = True
C = 8                   # Number of channels in graph

cuda:0


# Data

In [3]:
def COCOADataset(phase, transform=None, classification=classification):
    """Build the COCOA dataset object for one split.

    Args:
        phase: split name; must be 'train' or 'val' (any other key raises
            KeyError from the image-root lookup).
        transform: optional transform forwarded unchanged to the dataset class.
        classification: forwarded unchanged to the dataset class. The default
            is the value the global `classification` had when this function
            was defined.

    Returns:
        A `COCOADatasetDual` when the global `DUAL` is truthy, otherwise a
        `COCOADatasetGraph`, constructed from the split's annotation file,
        image directory and pixel-graph directory.
    """
    img_root = {'train': "../data/COCOA/train2014", 'val': "../data/COCOA/val2014"}[phase]
    annot_path = "../data/COCOA/annotations/COCO_amodal_{}2014.json".format(phase)
    graph_root = "../data/COCOA/pixel_graphs_{}/".format(phase)

    # Both dataset classes take the same arguments, so pick the class first.
    dataset_cls = COCOADatasetDual if DUAL else COCOADatasetGraph
    return dataset_cls(annot_path, img_root, graph_root, transform=transform, classification=classification)

# Training

In [4]:
# NOTE: an override that set `classification = False` whenever `hingeloss`
# was enabled used to sit here; it is currently disabled.

# Build the network variant that matches the DUAL switch, then move it to the
# selected device.
if not DUAL:
    model = UNet(out_channels=C, classification=classification)
else:
    model = UNetDual(seg_out_channels=91, graph_out_channels=C, classification=classification)
model = model.to(device)

24


In [5]:
# Preprocessing handed to both splits: convert to a tensor, then take a random
# crop of size 128.
transform = transforms.Compose([transforms.ToTensor(), transforms.RandomCrop(128)])

# NOTE: an override that set `classification = True` whenever `hingeloss` was
# enabled used to sit here; it is currently disabled.

# Train split is constructed first, then the validation split.
train_dataset, val_dataset = (
    COCOADataset(phase, transform=transform, classification=classification)
    for phase in ('train', 'val')
)

In [6]:
# Optimisation hyper-parameters.
lr = 0.01
momentum = 0.9
weight_decay = 1e-4
# Passed as the 5th positional trainer argument — presumably the batch size
# (TODO confirm against Trainer).
batch_size = 32
start_epoch = 0

# Wall-clock stamp that makes each run's checkpoint folder unique. It is kept
# under its own name so the `time` module is not shadowed: rebinding `time` to
# a float would make this cell fail with AttributeError when it is re-run.
run_timestamp = time.time()

folder = f'../tests/multi_hinge_lr{lr}_momentum{momentum}_weight_decay{weight_decay}_time{run_timestamp}'

optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
# Learning rate is multiplied by 0.1 at scheduler steps 100 and 150.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)

# Loss handed to the trainer for the graph output: weighted 3-class
# cross-entropy when `classification` is set, multi-label margin (hinge) loss
# otherwise.
if classification:
    class_weights = torch.tensor([.45, .1, .45], device=device)
    loss2 = nn.CrossEntropyLoss(weight=class_weights)
else:
    loss2 = nn.MultiLabelMarginLoss()

if DUAL:
    trainer = DualMaskTrainer(device, model, train_dataset, val_dataset, batch_size, optimizer, scheduler=scheduler,
                              losses=[nn.CrossEntropyLoss(), loss2], alpha=0.5, num_workers=2,
                              exp_name='dual', checkpoint_dir=folder)
else:
    # `hingeloss` comes from the configuration cell instead of being hard-coded.
    trainer = MaskTrainer(device, model, train_dataset, val_dataset, batch_size, optimizer, scheduler=scheduler,
                          loss=loss2, num_workers=2, exp_name='unet', checkpoint_dir=folder,
                          hingeloss=hingeloss, classification=classification)

In [7]:
# Start training with the configured trainer. The first argument (100) is
# presumably the number of epochs and `checkpoint=True` presumably enables
# checkpoint saving — TODO confirm against Trainer.train.
train_log = trainer.train(100, checkpoint=True, start_epoch=start_epoch)

  0%|          | 0/79 [00:00<?, ?it/s]

torch.Size([32, 115, 128, 128])
torch.Size([24, 32, 128, 128])


  3%|▎         | 2/79 [00:10<05:36,  4.37s/it]

torch.Size([32, 115, 128, 128])
torch.Size([24, 32, 128, 128])
torch.Size([32, 115, 128, 128])
torch.Size([24, 32, 128, 128])


  4%|▍         | 3/79 [00:19<08:00,  6.33s/it]

torch.Size([32, 115, 128, 128])
torch.Size([24, 32, 128, 128])


  6%|▋         | 5/79 [00:26<06:03,  4.92s/it]

torch.Size([32, 115, 128, 128])
torch.Size([24, 32, 128, 128])


  8%|▊         | 6/79 [00:26<04:06,  3.38s/it]

torch.Size([32, 115, 128, 128])
torch.Size([24, 32, 128, 128])
torch.Size([32, 115, 128, 128])
torch.Size([24, 32, 128, 128])


 10%|█         | 8/79 [00:35<04:16,  3.61s/it]

torch.Size([32, 115, 128, 128])
torch.Size([24, 32, 128, 128])


 11%|█▏        | 9/79 [00:44<06:12,  5.32s/it]

torch.Size([32, 115, 128, 128])
torch.Size([24, 32, 128, 128])


 13%|█▎        | 10/79 [00:45<04:21,  3.80s/it]

torch.Size([32, 115, 128, 128])
torch.Size([24, 32, 128, 128])


 14%|█▍        | 11/79 [00:54<06:09,  5.43s/it]

torch.Size([32, 115, 128, 128])
torch.Size([24, 32, 128, 128])


 14%|█▍        | 11/79 [00:54<05:36,  4.94s/it]
ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/home/kenchen10/.conda/envs/py38/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3437, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-7-2e4e35ea589d>", line 1, in <module>
    train_log = trainer.train(100, checkpoint=True, start_epoch=start_epoch)
  File "/home/kenchen10/spatial-segmentation/U-Net/Trainer.py", line 258, in train
    output_seg, output_graph = self.net(_in)
  File "/home/kenchen10/.conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/kenchen10/spatial-segmentation/U-Net/UNet.py", line 160, in forward
    enc4 = self.encoder4(self.pool3(enc3))
  File "/home/kenchen10/.conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/kenchen10/.conda/envs/py38/lib/python3.8/site-p

TypeError: object of type 'NoneType' has no len()

## Loading in Model

In [None]:
# Run directory (relative to ../tests/) and the checkpoint file inside it.
folder = 'multi_hinge_lr0.1_momentum0.9_weight_decay0.0001unet_1620001280.298803'
epoch = 'checkpoint_optim_0_1620001617.0382125'

# Full path of the checkpoint to restore.
PATH = f'../tests/{folder}/{epoch}'

def load_checkpoint(model, optimizer, scheduler, filename, map_location=None):
    """Restore model, optimizer and scheduler state from a checkpoint file.

    Args:
        model: object whose `load_state_dict` receives `checkpoint['state_dict']`.
        optimizer: object whose `load_state_dict` receives `checkpoint['optim']`.
        scheduler: object whose `load_state_dict` receives `checkpoint['sched']`.
        filename: path of the checkpoint file.
        map_location: forwarded to `torch.load`; pass e.g. the target device to
            load a checkpoint saved on a GPU onto a CPU-only machine. The
            default `None` keeps `torch.load`'s default behaviour.

    Returns:
        Tuple `(model, optimizer, scheduler, start_epoch, losslogger)`. When
        `filename` is not an existing file the three objects are returned
        untouched, `start_epoch` is 0 and `losslogger` is None.
    """
    start_epoch = 0
    # Defined up front so the missing-file path can still return a value
    # instead of raising UnboundLocalError.
    losslogger = None

    if not os.path.isfile(filename):
        print("=> no checkpoint found at '{}'".format(filename))
        return model, optimizer, scheduler, start_epoch, losslogger

    print("=> loading checkpoint '{}'".format(filename))
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    checkpoint = torch.load(filename, map_location=map_location)
    start_epoch = checkpoint['start_epoch']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optim'])
    scheduler.load_state_dict(checkpoint['sched'])
    losslogger = checkpoint['log']
    print("=> loaded checkpoint '{}' (epoch {})"
              .format(filename, checkpoint['start_epoch']))

    return model, optimizer, scheduler, start_epoch, losslogger

# Restore training state from PATH, then switch the model to evaluation mode
# and place it on the selected device.
restored = load_checkpoint(model, optimizer, scheduler, PATH)
model, optimizer, scheduler, start_epoch, log = restored
model = model.eval().to(device)

# Results

In [None]:
def plot_predictions(img, graph_pred, graph_target):
    """Show the input image, the ground-truth graph and the predicted graph in one row.

    Args:
        img: image tensor; it is permuted with (1, 2, 0) before display, so it
            is presumably channels-first — TODO confirm.
        graph_pred: prediction array. When the global `classification` is
            truthy, `graph_pred[:, 0, :, :]` is taken, the argmax over axis 0
            is computed and the result is shown shifted by -1. Otherwise the
            array is transposed with (1, 2, 0) and channel 0 is shown.
        graph_target: ground-truth array; transposed with (1, 2, 0), channel 0
            is shown.
    """
    plt.subplot(1, 3, 1)
    plt.imshow(img.permute(1, 2, 0))
    plt.title("Input")

    plt.subplot(1, 3, 2)
    plt.imshow(graph_target.transpose(1, 2, 0)[:, :, 0])
    plt.title("GT Graph 0")

    plt.subplot(1, 3, 3)
    if classification:
        class_scores = np.squeeze(graph_pred[:, 0, :, :])
        class_map = np.squeeze(np.argmax(class_scores, axis=0))
        # Shift the argmax indices down by one before display.
        plt.imshow(class_map - 1)
    else:
        plt.imshow(graph_pred.transpose(1, 2, 0)[:, :, 0])
    # Title the prediction panel in both modes (it was previously missing in
    # the classification branch).
    plt.title("Pred Graph 0")
    plt.show()
    
def plot_predictions_dual(img, graph_pred, graph_target, seg_pred, seg_target):
    """Draw a five-panel row: input, GT/predicted graph, GT/predicted segmentation.

    Args:
        img: image tensor, permuted with (1, 2, 0) before display.
        graph_pred: array transposed with (1, 2, 0); channel 0 is shown.
        graph_target: array transposed with (1, 2, 0); channel 0 is shown.
        seg_pred: array transposed with (1, 2, 0); channel 0 is shown.
        seg_target: array transposed with (1, 2, 0) and shown as is.
    """
    plt.figure(figsize=(15, 5))

    # (title, image factory) per panel, left to right. The factories are
    # evaluated lazily so each array is only touched right before it is drawn.
    panels = (
        ("Input", lambda: img.permute(1, 2, 0)),
        ("GT Graph 0", lambda: graph_target.transpose(1, 2, 0)[:, :, 0]),
        ("Pred Graph 0", lambda: graph_pred.transpose(1, 2, 0)[:, :, 0]),
        ("GT Segmentation", lambda: seg_target.transpose(1, 2, 0)),
        ("Pred Segmentation", lambda: seg_pred.transpose(1, 2, 0)[:, :, 0]),
    )
    for position, (caption, make_image) in enumerate(panels, start=1):
        plt.subplot(1, 5, position)
        plt.imshow(make_image())
        plt.title(caption)

    plt.show()

In [None]:
# Toy check of np.argmax along the stacking axis: three 2x4 grids are stacked
# into shape (3, 2, 4); argmax over axis 0 then reports, for every cell, which
# grid holds the largest value (the first one wins on ties).
arr = np.array([[4, 5, 6, 8], [1, 2, 3, 4]])
arr2 = np.array([[5, 1, 2, 3], [4, 3, 2, 1]])
arr3 = np.array([[8, 9, 1, 2], [1, 4, 3, 2]])
combined = np.stack([arr, arr2, arr3])
print(combined)
combined.argmax(axis=0)

In [None]:
# Take the first batch from the validation loader and evaluate it with the
# trainer; ground-truth tensors are converted to NumPy for plotting.
batch = next(iter(trainer.val_loader))
if not DUAL:
    img, gt, _ = batch
    gt = gt.numpy()
    graph_pred = trainer.evaluate(batch)
else:
    # In dual mode the second batch element is a (segmentation, graph) pair.
    img, targets, _ = batch
    seg_gt, graph_gt = (target.numpy() for target in targets)
    seg_pred, graph_pred = trainer.evaluate(batch)

In [None]:
# Plot every sample of the evaluated batch with the plotting helper that
# matches the DUAL switch.
num_samples = len(img)
for i in range(num_samples):
    if not DUAL:
        plot_predictions(img[i], graph_pred[i], gt[i])
        continue
    plot_predictions_dual(img[i], graph_pred[i], graph_gt[i], seg_pred[i], seg_gt[i])