In [41]:
import os

import torch
import torch.nn as nn
import torchvision

import albumentations as A
from albumentations.pytorch import ToTensorV2

from dataset import SBDDataset

In [42]:
transform = A.Compose([A.Normalize(), ToTensorV2()])

# Root of the SBD release. expanduser() resolves '~' here, so the path works whether
# or not SBDDataset expands it itself (plain open() does not understand '~').
SBD_ROOT = os.path.expanduser('~/data/datasets/VOC/benchmark_RELEASE/dataset/')

trainset = SBDDataset(SBD_ROOT, 'train', transform)
testset = SBDDataset(SBD_ROOT, 'val', transform)

In [43]:
from torch.utils.data import DataLoader

# NOTE(review): the transform only normalises (no resize/crop), so samples presumably keep
# their native, varying sizes and cannot be stacked by the default collate function --
# keep bs = 1 unless a fixed-size augmentation is added.
bs = 1

trainloader = DataLoader(
    dataset=trainset,
    batch_size=bs,
    shuffle=True,   # new sample order every epoch
    num_workers=2,
    pin_memory=True,
)
testloader = DataLoader(
    dataset=testset,
    batch_size=bs,
    shuffle=False,  # fixed order for validation
    num_workers=2,
    pin_memory=True,
)

In [44]:
class FCN32(nn.Module):
    """FCN-32s semantic segmentation network on a VGG16-style backbone.

    ``features`` is the VGG16 convolutional stack (13 3x3 convs with ReLU, 5 max-pools) laid
    out with the same module order/indices as ``torchvision.models.vgg16().features``.
    ``classifier`` is VGG's fully connected head expressed as convolutions, and ``upscore``
    bilinearly upsamples the stride-32 score map by 32 before ``forward`` crops it back to
    the input size.

    Args:
        num_classes: number of score channels produced per pixel.
        fc_channels: width of the two convolutionalised FC layers. The default 4096 matches
            VGG16; smaller values give a lighter model (e.g. for quick experiments/tests).
    """

    # VGG16 layout: an int is the output width of a 3x3 conv (+ReLU), 'M' a 2x2/stride-2 max-pool.
    _VGG16_CFG = (64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
                  512, 512, 512, 'M', 512, 512, 512, 'M')

    def __init__(self, num_classes, fc_channels=4096):
        super().__init__()

        self.features = self._make_features(self._VGG16_CFG)

        self.classifier = nn.Sequential(
            nn.Conv2d(512, fc_channels, kernel_size=7, padding=3),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Conv2d(fc_channels, fc_channels, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Conv2d(fc_channels, num_classes, kernel_size=1),
        )

        self.upscore = nn.Upsample(mode='bilinear', scale_factor=32)

    @staticmethod
    def _make_features(cfg, in_channels=3):
        """Build the VGG conv stack from ``cfg``; the first conv is padded by 17 instead of 1.

        The extra padding makes the first conv output (H + 32) x (W + 32). After the five
        stride-2 pools the map is floor(H/32) + 1 cells tall (likewise for W), so the x32
        upsampling yields at least H + 1 rows and W + 1 columns. That is always enough for
        ``forward`` to crop back to the input size, for any H and W.
        """
        layers = []
        for item in cfg:
            if item == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            padding = 17 if not layers else 1
            layers.append(nn.Conv2d(in_channels, item, kernel_size=3, stride=1, padding=padding))
            layers.append(nn.ReLU(inplace=True))
            in_channels = item
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map an image batch (N, 3, H, W) to per-pixel class scores (N, num_classes, H, W)."""
        height, width = x.shape[-2:]
        out = self.upscore(self.classifier(self.features(x)))
        # The upsampled map is never smaller than the input (see _make_features): centre-crop it.
        top = (out.size(-2) - height) // 2
        left = (out.size(-1) - width) // 2
        return out[..., top:top + height, left:left + width]


In [45]:
net = FCN32(num_classes=21)  # presumably 20 PASCAL VOC object classes + background

In [46]:
# NOTE(review): lr=1e-10 / momentum=0.99 appear to follow the FCN reference recipe, which pairs a
# summed loss with VGG16-pretrained weights; FCN32 here starts from random init -- confirm intent.
optimizer = torch.optim.SGD(params=net.parameters(), momentum=0.99, lr=1e-10)

In [47]:
criterion = nn.CrossEntropyLoss(reduction="sum")  # summed over pixels, not averaged

In [49]:
from tqdm import tqdm
import torchfcn
import numpy as np

class Trainer:
    """Plain epoch loop for per-pixel classification models such as FCN32.

    The trainer keeps no state of its own: the model, optimizer, loss and data
    loaders are all passed to :meth:`train`.
    """

    def train(self, model, optimizer, criterion, train_loader, val_loader, max_epoch):
        """Train ``model`` for ``max_epoch`` epochs, validating after each epoch.

        Args:
            model: module mapping an input batch to class scores of shape
                (N, C, H, W); predictions are the argmax over dim 1.
            optimizer: optimizer over ``model``'s parameters.
            criterion: loss called as ``criterion(scores, target)`` with an int64
                target of shape (N, H, W). If it has an ``ignore_index`` attribute
                (as ``nn.CrossEntropyLoss`` does), those pixels are also left out
                of the accuracy.
            train_loader, val_loader: sized iterables of ``(data, target)`` batches;
                a target of shape (N, 1, H, W) is squeezed to (N, H, W).
            max_epoch: number of epochs to run.

        Returns:
            A list with one dict per epoch holding ``epoch``, ``train_loss``,
            ``train_acc``, ``val_loss`` and ``val_acc``. Losses are the mean over
            batches of ``criterion`` divided by the batch size; accuracies are
            pixel accuracies in [0, 1].

        Raises:
            ValueError: if the loss becomes NaN.
        """
        # Batches are moved to wherever the model's weights live (CPU or GPU).
        device = next(model.parameters()).device
        history = []

        for epoch in range(1, max_epoch + 1):
            train_loss, train_acc = self._run_epoch(model, criterion, train_loader, device, optimizer)
            print('train', train_loss, train_acc)

            val_loss, val_acc = self._run_epoch(model, criterion, val_loader, device)
            print('val', val_loss, val_acc)

            history.append({
                'epoch': epoch,
                'train_loss': train_loss,
                'train_acc': train_acc,
                'val_loss': val_loss,
                'val_acc': val_acc,
            })

        return history

    @staticmethod
    def _prepare_target(target, device):
        """Move ``target`` to ``device`` as int64 and drop a singleton channel dimension."""
        # .long() keeps the tensor on `device`; .type(torch.LongTensor) would force it to the CPU.
        target = target.to(device).long()
        if target.dim() == 4 and target.size(1) == 1:
            target = target.squeeze(1)
        return target

    def _run_epoch(self, model, criterion, loader, device, optimizer=None):
        """Run one pass over ``loader``: train when ``optimizer`` is given, else evaluate.

        Returns ``(mean per-sample loss, pixel accuracy)``; both are 0.0 for an empty loader.
        """
        training = optimizer is not None
        phase = 'training' if training else 'validating'
        ignore_index = getattr(criterion, 'ignore_index', None)

        model.train(training)  # train(False) == eval(): disables dropout during validation
        total_loss = 0.0
        correct = 0
        counted = 0

        with torch.set_grad_enabled(training):
            for data, target in tqdm(loader, total=len(loader)):
                data = data.to(device)
                target = self._prepare_target(target, device)

                if training:
                    optimizer.zero_grad()
                score = model(data)
                loss = criterion(score, target)

                loss_value = loss.item()
                if np.isnan(loss_value):
                    raise ValueError('loss is nan while ' + phase)

                if training:
                    loss.backward()
                    optimizer.step()

                total_loss += loss_value / len(data)

                # Pixel accuracy over the pixels the loss actually scores.
                pred = score.argmax(dim=1)
                if ignore_index is None:
                    valid = torch.ones_like(target, dtype=torch.bool)
                else:
                    valid = target.ne(ignore_index)
                correct += (pred.eq(target) & valid).sum().item()
                counted += valid.sum().item()

        mean_loss = total_loss / max(len(loader), 1)
        accuracy = correct / counted if counted else 0.0
        return mean_loss, accuracy

In [50]:
# Trainer takes no constructor arguments; the model, optimizer, loss and loaders go to train().
trainer = Trainer()

In [51]:
trainer.train(
    net,
    optimizer,
    criterion,
    train_loader=trainloader,
    val_loader=testloader,
    max_epoch=11,
)

  0%|          | 39/8498 [01:03<3:49:39,  1.63s/it]


KeyboardInterrupt: 