In [1]:
import os
import math
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim

from collections import OrderedDict

import numpy as np
import matplotlib.pyplot as plt

# Query CUDA availability once and reuse the answer for `device` below.
use_cuda = torch.cuda.is_available()

# Local data directory. makedirs(..., exist_ok=True) is a no-op when the
# directory already exists, so this cell can be re-run safely.
root = './data'
os.makedirs(root, exist_ok=True)

# Device handle for placing models and tensors; falls back to the CPU when
# no GPU is visible.
device = torch.device('cuda' if use_cuda else 'cpu')
device

device(type='cuda')

### Gradient Reversal Function

In [13]:
# gradient reversal
from torch.autograd import Function

# Autograd Function objects are what record operation history on tensors
# and define formulas for the forward pass and backprop.

class GradientReversalFn(Function):
    """Gradient reversal implemented as a custom autograd Function.

    Forward returns the input values unchanged; backward returns the incoming
    gradient multiplied by ``-alpha``.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # Keep the scaling factor on the context so backward() can read it.
        ctx.alpha = alpha
        # Values pass through untouched.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign of the upstream gradient, then scale it by alpha.
        reversed_grad = grad_output.neg().mul(ctx.alpha)
        # One return value per forward() input: a gradient for x, none for alpha.
        return reversed_grad, None

### Model

In [58]:
class SeparableConv2d(nn.Module):
    """Depthwise-separable 2-D convolution.

    A per-channel (``groups=in_channels``) spatial convolution followed by a
    1x1 pointwise convolution that mixes channels.

    Args:
        in_channels: number of input channels.
        out_channels: number of channels produced by the pointwise conv.
        kernel_size: kernel size of the depthwise conv.
        bias: whether both convolutions use a bias term.
        padding: zero padding of the depthwise conv. The default of 1 keeps
            the spatial size for ``kernel_size=3``; pass ``kernel_size // 2``
            to keep it for other odd kernel sizes.
    """

    def __init__(self, in_channels, out_channels, kernel_size, bias=False, padding=1):
        super(SeparableConv2d, self).__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size,
                                   groups=in_channels, bias=bias, padding=padding)
        self.pointwise = nn.Conv2d(in_channels, out_channels,
                                   kernel_size=1, bias=bias)

    def forward(self, x):
        """Apply the depthwise conv, then the pointwise conv."""
        return self.pointwise(self.depthwise(x))


class DASemantic(nn.Module):
    """Segmentation network with a domain-adversarial branch.

    An encoder (entry conv + three residual separable-conv blocks) feeds two
    heads:

    * a decoder (``upsample1`` + ``head``) producing per-pixel class scores;
    * a domain classifier that sees the flattened encoder features through a
      gradient-reversal layer and outputs 2-way log-probabilities.

    Args:
        img_size: side length of the (square) input images. The size of the
            domain classifier's first linear layer is derived from it.
        num_classes: number of output channels of the segmentation head.
    """

    # BatchNorm hyper-parameters shared by every 2-D batch-norm layer.
    # The previous momentum of 0.99 looks like a Keras-style decay (0.99 keeps
    # 99% of the old running statistic). PyTorch's `momentum` is the weight of
    # the *new* batch statistic, so the equivalent setting is 1 - 0.99 = 0.01;
    # with 0.99 the running statistics track almost only the last batch.
    BN_EPS = 1e-3
    BN_MOMENTUM = 0.01

    def __init__(self, img_size, num_classes):
        super(DASemantic, self).__init__()
        self.img_size = img_size
        # Side length of the encoder output (45 for img_size=720).
        self.feature_size = self._encoder_output_size(img_size)
        bn = self._batch_norm

        self.entry = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2),
            bn(32),
            nn.ReLU()
        )
        # Each encoder block halves the spatial size; the matching 1x1
        # stride-2 conv projects the skip connection to the same shape.
        self.block1 = self._down_block(32, 32)
        self.residual1 = nn.Conv2d(32, 32, 1, stride=2)
        self.block2 = self._down_block(32, 64)
        self.residual2 = nn.Conv2d(32, 64, 1, stride=2)
        self.block3 = self._down_block(64, 128)
        self.residual3 = nn.Conv2d(64, 128, 1, stride=2)
        self.domain_classifier = nn.Sequential(
            nn.Linear(128 * self.feature_size * self.feature_size, 100),
            nn.BatchNorm1d(100),
            nn.ReLU(True),
            nn.Linear(100, 2),
            nn.LogSoftmax(dim=1)
        )
        self.upsample1 = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(128, 128, 3, padding=1),
            bn(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 128, 3, padding=1),
            bn(128),
            nn.UpsamplingNearest2d(scale_factor=2)
        )
        self.residual4 = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(128, 128, 1)
        )
        self.head = nn.Sequential(
            nn.ReLU(),
            SeparableConv2d(128, 64, 3),
            bn(64),
            nn.ReLU(),
            SeparableConv2d(64, 32, 3),
            bn(32),
            nn.Conv2d(32, num_classes, 3)
        )

    @classmethod
    def _batch_norm(cls, num_features):
        """Build a BatchNorm2d layer with the class-wide eps/momentum."""
        return nn.BatchNorm2d(num_features=num_features, eps=cls.BN_EPS,
                              momentum=cls.BN_MOMENTUM)

    @classmethod
    def _down_block(cls, in_channels, out_channels):
        """Build one encoder block: two separable convs, then a stride-2 max-pool."""
        return nn.Sequential(
            nn.ReLU(),
            SeparableConv2d(in_channels, out_channels, 3),
            cls._batch_norm(out_channels),
            nn.ReLU(),
            SeparableConv2d(out_channels, out_channels, 3),
            cls._batch_norm(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1)
        )

    @staticmethod
    def _encoder_output_size(img_size):
        """Return the side length of the encoder output for a square input."""
        # Entry block: 3x3 conv, stride 2, no padding.
        size = (img_size - 3) // 2 + 1
        # Blocks 1-3: MaxPool2d(3, stride=2, padding=1) and the 1x1 stride-2
        # skip conv both map s -> (s - 1) // 2 + 1.
        for _ in range(3):
            size = (size - 1) // 2 + 1
        return size

    def forward(self, x, grl_lambda):
        """Run the segmentation and domain branches.

        Args:
            x: batch of images with 1 or 3 channels and spatial size
                ``img_size`` x ``img_size``.
            grl_lambda: scale applied to the reversed gradient that flows
                from the domain classifier back into the encoder.

        Returns:
            ``(output, domain_pred)``: ``output`` has ``num_classes`` channels
            and spatial size ``2 * feature_size - 2`` (88 for a 720 input);
            ``domain_pred`` holds log-probabilities of shape ``(batch, 2)``.
        """
        # Handle single-channel input by expanding the singleton channel dim.
        x = x.expand(x.shape[0], 3, self.img_size, self.img_size)

        # Encoder: entry block followed by three residual blocks.
        x = self.entry(x)
        skip = x

        x = torch.add(self.block1(x), self.residual1(skip))
        skip = x

        x = torch.add(self.block2(x), self.residual2(skip))
        skip = x

        x = torch.add(self.block3(x), self.residual3(skip))
        skip = x

        # Domain branch input: flatten per sample (keeps the batch dimension
        # intact even if the feature-map size is not what was expected) and
        # reverse the gradient flowing back into the encoder.
        features = torch.flatten(x, start_dim=1)
        reverse_features = GradientReversalFn.apply(features, grl_lambda)

        # Decoder: one residual upsampling stage, then the segmentation head.
        x = torch.add(self.upsample1(x), self.residual4(skip))
        output = self.head(x)

        domain_pred = self.domain_classifier(reverse_features)

        return output, domain_pred

# Smoke test: push a random batch through the network and show the shape of
# the segmentation output. Uses `device` from the setup cell so it also runs
# on machines without a GPU; no_grad() avoids building an autograd graph.
model = DASemantic(720, 2).to(device)
batch = torch.rand(2, 3, 720, 720).to(device)
with torch.no_grad():
    print(model(batch, 1.)[0].shape)


        
        

torch.Size([2, 2, 88, 88])


### Datasets

In [168]:
import os
from glob import glob as glob
import cv2 as cv
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms as T
import matplotlib.pyplot as plt


class CustomImageDataset(Dataset):
    """Image / mask / domain-label dataset backed by two directories.

    Every file in ``img_dir`` is treated as an image. Its mask is looked up in
    ``mask_dir`` as ``mask<suffix>.jpg``, where ``<suffix>`` is the image's
    base name (without extension) with its first three characters removed.

    Args:
        img_dir: directory containing the images.
        mask_dir: directory containing the masks.
        domain_label: truthy -> every sample is labelled 1.0, else 0.0.
        img_transform: optional callable applied to the loaded image.
        mask_transform: optional callable applied to the loaded mask.

    Note: despite their names, ``self.img_dir`` and ``self.mask_dir`` hold
    lists of file paths (kept for backward compatibility).
    """

    def __init__(self, img_dir, mask_dir, domain_label, img_transform=None, mask_transform=None):
        # List the directory once and sort it: glob() gives no ordering
        # guarantee, and sorting makes sample indices reproducible.
        image_paths = sorted(glob(os.path.join(img_dir, "*")))
        self.img_dir = image_paths
        self.mask_dir = [
            os.path.join(mask_dir, f"mask{os.path.splitext(os.path.basename(path))[0][3:]}.jpg")
            for path in image_paths
        ]
        # Float labels, as before. NOTE(review): nn.NLLLoss expects integer
        # (long) targets, so cast with .long() before using these as targets.
        n_samples = len(image_paths)
        self.domain_label = torch.ones(n_samples) if domain_label else torch.zeros(n_samples)
        self.img_transform = img_transform
        self.mask_transform = mask_transform

    def __len__(self):
        return len(self.img_dir)

    @staticmethod
    def _read(path):
        """Load an image with OpenCV, raising instead of returning None."""
        data = cv.imread(path)
        if data is None:
            raise FileNotFoundError(f"Could not read image file: {path}")
        return data

    def __getitem__(self, idx):
        """Return ``(image, mask, domain)`` for sample ``idx``."""
        # OpenCV decodes colour images as BGR; convert to RGB so channel-wise
        # transforms (e.g. normalisation with RGB statistics) line up.
        image = cv.cvtColor(self._read(self.img_dir[idx]), cv.COLOR_BGR2RGB)
        if self.img_transform is not None:
            image = self.img_transform(image)

        # The mask is read with cv.imread's default flags, unchanged from
        # the previous behaviour.
        mask = self._read(self.mask_dir[idx])
        if self.mask_transform is not None:
            mask = self.mask_transform(mask)

        domain = self.domain_label[idx]

        return image, mask, domain
    

In [169]:
# Custom Transform

class Threshold(object):
    """Binarise an image in place.

    Values ``<= threshold`` become 0 and values ``> threshold`` become 255.
    The input object is modified and returned.

    Args:
        threshold: integer cut-off value.
    """

    def __init__(self, threshold):
        assert isinstance(threshold, (int))
        self.threshold = threshold

    def __call__(self, img):
        # Compute both masks before writing anything. Previously the second
        # mask was taken after the first write, so with a negative threshold
        # the freshly written zeros counted as "above" and everything
        # ended up 255.
        at_or_below = img <= self.threshold
        above = img > self.threshold
        img[at_or_below] = 0
        img[above] = 255

        return img
        

In [175]:
im_dir = '/home/ubuntu/workspace/create_train_set/data/dataset_out_mask/images/train'
mask_dir = '/home/ubuntu/workspace/create_train_set/data/dataset_out_mask/masks/train'

# Image pipeline: to tensor, per-channel normalisation, resize to the
# 720x720 input size the model is built for.
transform_img = transforms.Compose(
    [
        T.ToTensor(),
        T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        T.Resize((720,720))
    ]
)

# Mask pipeline: binarise, to tensor (no normalisation), resize to 88x88,
# the spatial size of the model's segmentation output for a 720x720 input.
# NOTE(review): Resize interpolates, so the resized mask may no longer be
# strictly binary - consider nearest-neighbour interpolation; TODO confirm.
transform_mask = transforms.Compose(
    [
        Threshold(30),
        T.ToTensor(),
        T.Resize((88,88))
    ]
)

# Source-domain dataset (domain label 0).
ds_source = CustomImageDataset(im_dir, mask_dir, 0, transform_img, transform_mask)


# Wrap the dataset built above; the previous `VirtualDataset` argument is not
# defined anywhere in the notebook.
dl_source = DataLoader(ds_source, batch_size=4, shuffle=True)

print(next(iter(dl_source)))

[tensor([[[[ 0.4516,  0.4186,  0.3509,  ...,  0.0027, -0.0640, -0.0827],
          [ 0.4772,  0.4281,  0.4232,  ...,  0.0029, -0.0421, -0.0483],
          [ 0.4602,  0.4163,  0.4906,  ...,  0.0045, -0.0473, -0.0312],
          ...,
          [ 1.9521,  1.9814,  1.9831,  ...,  0.4361,  0.5631,  0.5930],
          [ 2.0022,  1.9722,  1.9453,  ...,  0.5736,  0.5870,  0.6008],
          [ 1.9127,  1.9565,  2.0406,  ...,  0.6192,  0.6194,  0.5094]],

         [[ 0.5911,  0.5574,  0.4882,  ...,  0.1322,  0.0640,  0.0449],
          [ 0.6173,  0.5671,  0.5621,  ...,  0.1324,  0.0864,  0.0801],
          [ 0.5999,  0.5551,  0.6311,  ...,  0.1340,  0.0811,  0.0976],
          ...,
          [ 2.1252,  2.1551,  2.1569,  ...,  0.5753,  0.7052,  0.7357],
          [ 2.1763,  2.1456,  2.1182,  ...,  0.7159,  0.7296,  0.7437],
          [ 2.0848,  2.1296,  2.2156,  ...,  0.7625,  0.7627,  0.6502]],

         [[ 0.8107,  0.7771,  0.7083,  ...,  0.3539,  0.2860,  0.2669],
          [ 0.8368,  0.7868, 

### Training

In [27]:
lr = 1e-3
n_epochs = 1

# Setup model and optimizer. The model is placed on `device` from the setup
# cell, so this also works on machines without a GPU.
model = DASemantic(720, 2).to(device)
optimizer = optim.Adam(model.parameters(), lr)

# Two loss functions (pixelwise class and domain).
# reduction="none" makes the class loss return one value per pixel; it must
# be reduced (e.g. with .mean()) before calling backward().
loss_fn_class = nn.CrossEntropyLoss(reduction="none")
# The domain classifier ends in LogSoftmax, hence NLLLoss.
loss_fn_domain = nn.NLLLoss()


In [None]:
# Training batch size.
batch_size = 32

# TODO: get data - build the source/target loaders and the iterators
# (`dl_source_iter`, `dl_target_iter`) that the training loop reads.

# TODO: find number of batches to run for - the training loop reads this
# value as `max_batches`.


In [None]:
model.train()
for epoch_idx in range(n_epochs):
    print(f'Epoch {epoch_idx+1:04d} / {n_epochs:04d}', end='\n=================\n')
    # TODO: source data iter (`dl_source_iter`)
    # TODO: target data iter (`dl_target_iter`)

    for batch_idx in range(max_batches):
        # Training progress, halved: p runs over [0, 0.5), so grl_lambda ramps
        # smoothly from 0 towards 2 / (1 + exp(-5)) - 1 (about 0.987).
        p = float(((batch_idx + epoch_idx * max_batches) / (n_epochs * max_batches))/2)
        grl_lambda = 2. / (1. + np.exp(-10*p)) - 1

        # Clear gradients left over from the previous step; without this they
        # accumulate across iterations.
        optimizer.zero_grad()

        # Train on source domain
        # Star-unpacking accepts both (image, mask) and (image, mask, domain)
        # batches.
        X_s, y_s, *_ = next(dl_source_iter)
        X_s, y_s = X_s.to(device), y_s.to(device)
        # Size the domain labels from the actual batch so a short final batch
        # does not cause a shape mismatch.
        y_s_domain = torch.zeros(X_s.size(0), dtype=torch.long, device=device) # source domain labels

        class_pred, domain_pred = model(X_s, grl_lambda)
        # loss_fn_class may be configured with reduction="none" (per-pixel
        # losses); reduce to a scalar so backward() and .item() work.
        # NOTE(review): y_s must be in a target format CrossEntropyLoss
        # accepts for class_pred - verify against the mask transform.
        loss_s_label = loss_fn_class(class_pred, y_s).mean()
        loss_s_domain = loss_fn_domain(domain_pred, y_s_domain)

        # Train on target domain
        X_t, *_ = next(dl_target_iter) # ignore target domain class labels!
        X_t = X_t.to(device)
        y_t_domain = torch.ones(X_t.size(0), dtype=torch.long, device=device) # target domain labels

        _, domain_pred = model(X_t, grl_lambda)
        loss_t_domain = loss_fn_domain(domain_pred, y_t_domain)

        # Calculate total loss
        loss = loss_t_domain + loss_s_domain + loss_s_label
        loss.backward()
        optimizer.step()

        if batch_idx % 10 == 0:
            print(f'[{batch_idx+1}/{max_batches}] '
                  f'class_loss: {loss_s_label.item():.4f} ' f's_domain_loss: {loss_s_domain.item():.4f} '
                  f't_domain_loss: {loss_t_domain.item():.4f} ' f'grl_lambda: {grl_lambda:.3f} '
                 )
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    

In [6]:
# Scratch check: torch.add on two random 2x2 tensors scaled into [0, 100).
a = 100 * torch.rand(2, 2)
b = 100 * torch.rand(2, 2)
# Element-wise sum of the two operands.
c = a + b
print(a, " + ", b)
print(c)

tensor([[35.4760, 99.9222],
        [43.7156, 72.4414]])  +  tensor([[59.5805, 27.1421],
        [19.7191, 31.2870]])
tensor([[ 95.0565, 127.0644],
        [ 63.4347, 103.7284]])
