In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torchvision import datasets, transforms, models


from pycocotools.coco import COCO
from pycocotools.mask import *
import os
import json
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

In [2]:
device = torch.device("cuda: 0" if torch.cuda.is_available() else "cpu")

In [3]:
IMG_SIZE = 224

In [4]:
data_transform = {
    "train": transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])
}

In [5]:
DATA_DIR = {
    "coco": {
        "img": {
            "train": "data/coco/train2017"},
        "instances": {
            "train": "data/coco/annotations/instances_train2017.json"}}
}

In [6]:
class COCODataset(data.Dataset):
    def __init__(self, image_dir, annotation_dir, transform=None):
        self.img_dir = image_dir
        self.annotation_dir = annotation_dir
        self.transform = transform
        self.coco = COCO(self.annotation_dir)
        self.cats = self.coco.cats
        allImgIds = self.coco.getImgIds()
        self.imgs = self.coco.loadImgs(allImgIds)
        
    def __getitem__(self, idx):
        
        
        img = self.imgs[idx]
        
        file_name = img["file_name"]
        img_idx = img["id"]
        image = Image.open(os.path.join(self.img_dir, file_name))
        if self.transform is not None:
            image = self.transform(image)
            
        annIds = self.coco.getAnnIds(imgIds=img_idx, iscrowd=None)
        anns = self.coco.loadAnns(annIds)
        
        img_width = img["width"]
        img_height = img["height"]
        
        masks = np.zeros((img_height,img_width))
        categories = []
        for i in range(len(anns)):
            mask = self.coco.annToMask(anns[i])
            categories.append(anns[i]["category_id"])
            mask = mask*categories[-1]                          # (shape: (h,w))
            masks =  np.array([[max(d,e) for d,e in zip(i,j)] for i,j in zip(masks,mask)])      # (shape: (h,w))
            
        masks = torch.LongTensor(masks)
        
        item = {
                "filename": file_name,
                "image": image,
                "width":  img_width,
                "height": img_height,
                "mask": masks,
                "category": categories}
        
        return item
        
    def imshow(self, item):
        image = Image.open(os.path.join(self.img_dir, item["filename"]))
        plt.imshow(image)
        seg_imageGray=np.zeros((item['height'],item['width'])).astype(np.uint8)
        seg_imageRGB=np.zeros((item['height'],item['width'],3)).astype(np.uint8)
        for i in range(len(item["mask"])):
            img_cats = self.cats[item["category"][i]]["name"]
            seg_image = item["mask"][i]
            seg_image=(seg_image-(seg_image&seg_imageGray))
            imgRGB = np.zeros((item['height'],item['width'],3))
            color_mask = np.random.random((1, 3)).tolist()[0]
            imgRGB[:,:,0]=((seg_imageGray|seg_image)==1)*color_mask[0]
            imgRGB[:,:,1]=((seg_imageGray|seg_image)==1)*color_mask[1]
            imgRGB[:,:,2]=((seg_imageGray|seg_image)==1)*color_mask[2]
            seg_imageRGB=seg_imageRGB+imgRGB
            seg_imageRGB = np.clip(seg_imageRGB,0,1)
            print(img_cats)
        plt.imshow(seg_imageRGB, cmap="jet", alpha=0.4)
        plt.axis("off")
        plt.show()
        
    def __len__(self):
        return len(self.imgs)

In [7]:
trainset = COCODataset(DATA_DIR["coco"]["img"]["train"],
                       DATA_DIR["coco"]["instances"]["train"], transform=data_transform["train"])

NUM_CLASSES = len(trainset.cats)

loading annotations into memory...
Done (t=16.80s)
creating index...
index created!


In [8]:
BATCH_SIZE = 1
trainloader = data.DataLoader(trainset,batch_size = BATCH_SIZE, num_workers=0)

In [9]:
class FCN(nn.Module):
    def __init__(self, NUM_CLASSES):
        super(FCN, self).__init__()
        self.fcn = models.segmentation.fcn_resnet101(pretrained=True,  aux_loss=None).train()
        modules2 = list(self.fcn.children())[1]
        modules3 = list(self.fcn.children())[2]
        
       
        for params in self.fcn.parameters():
            params.requires_grad = False
            
        modules2[4] = nn.Conv2d(512,NUM_CLASSES, kernel_size=(1, 1), stride=(1, 1))
        modules3[4] = nn.Conv2d(256,NUM_CLASSES, kernel_size=(1, 1), stride=(1, 1))

        self.fcn.classifier = nn.Sequential(*modules2)
        self.fcn.aux_classifier = nn.Sequential(*modules3)
        
    def forward(self, x):
        x = self.fcn(x)
        return x

In [10]:
model = FCN(NUM_CLASSES).to(device)
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 61,600 trainable parameters


In [11]:
print(model)

FCN(
  (fcn): FCN(
    (backbone): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(in

In [12]:
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
criterion = nn.CrossEntropyLoss()

In [13]:
def predict_accuracy(output, masks):
    prediction = torch.max(output, dim = 1)[1]
    prediction = prediction.reshape(-1)
    masks = masks.reshape(-1)
    correct = (prediction==masks).sum()
    total = len(prediction)
    accuracy = correct*1.0/total
    return accuracy

In [17]:
from tqdm.notebook import tqdm
def train_process(model, optimizer, criterion, trainloader):
    model.train()
    for idx,batch in tqdm(enumerate(trainloader)):
        img = batch["image"].to(device)
        masks = batch["mask"].squeeze(1).to(device)
        cats = batch["category"]
        output = model(img)["out"]
        
        loss = criterion(output, masks)
        acc = predict_accuracy(output, masks)
        loss.backward()
        optimizer.step()
        
        del img,masks
        
        
        if idx%300 == 0:
            print(f"Iteration: {idx} | Loss: {loss.data} | Accuracy: {acc}")
    
    return loss, acc

In [18]:
def train(model, optimizer, criterion, trainloader):
    EPOCHS = 5
    for epoch in range(EPOCHS):
        train_loss, train_acc = train_process(model, optimizer, criterion, trainloader)
        print(f"Epoch: {epoch} | Loss: {train_loss.data} | Accuracy: {train_acc}")

In [19]:
TRAIN_DATA = 1
if TRAIN_DATA:
    train(model,optimizer, criterion, trainloader)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

RuntimeError: CUDA error: device-side assert triggered