In [1]:
import torch
import torch.nn as nn
import os
import sys
import cv2
import glob
import numpy as np
import matplotlib.pyplot as plt
# Pytorch
import torch
import torch.optim as optim
import torchvision.transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
# Remember the notebook's own working directory exactly once, so re-running
# this cell never overwrites it with some other directory.
if 'pwd' not in globals():
    pwd = os.getcwd()
# The project-local `modules` package is imported with the parent of the
# notebook directory as the working directory (presumably that is where the
# package lives -- confirm against the repo layout).
# The target is anchored on `pwd` rather than a bare '..' so the cell is
# idempotent, and the original directory is restored even if an import fails.
os.chdir(os.path.join(pwd, os.pardir))
try:
    from modules.Dataset import BeeDataset
    import modules.transforms as T
finally:
    os.chdir(pwd)

In [2]:
# Value handed to T.RandomCropper in the transform pipeline
# (presumably the crop side length in pixels -- TODO confirm).
CROP_DIM = 512
# Path given to BeeDataset as `data_root`; relative to the notebook's working directory.
data_root = '../data/processed'

In [3]:
# Ordered transform steps that are later passed to BeeDataset.
# The exact behaviour of each step lives in modules.transforms; judging by the
# names this is crop -> flip -> rotate -> tensor conversion -> normalisation
# (verify against that module).
augmentation_steps = [
    T.RandomCropper(CROP_DIM),
    T.LRFlipper(),
    T.Rotator(),
    T.ToTensor(),
    T.Normalizer(),
]
transforms = torchvision.transforms.Compose(augmentation_steps)

In [4]:
# Build the dataset rooted at `data_root`, passing it the `transforms` pipeline.
dataset = BeeDataset(data_root=data_root, transforms=transforms)

Loading paths...
Num paths loaded: 10


In [5]:
# Serve the dataset in mini-batches of two. No `shuffle` argument is given,
# so the DataLoader's default (sequential, unshuffled) ordering applies.
loader = DataLoader(dataset, batch_size=2)

In [6]:
def cprint(key, x, verbose=None):
    """Print a label and the shape of ``x`` when verbose output is enabled.

    Args:
        key: Label printed in a left-justified, 15-character-wide column.
        x: Any object exposing a ``.shape`` attribute (e.g. a torch tensor).
        verbose: Explicit on/off switch. When ``None`` (the default) the
            module-level ``VERBOSE`` flag is used; if that flag has not been
            defined yet, nothing is printed.
    """
    if verbose is None:
        # Looked up at call time so the flag can be toggled from a later cell.
        verbose = globals().get('VERBOSE', False)
    if verbose:
        # Read the shape directly: converting to numpy first would fail for
        # tensors that are not on the CPU and is unnecessary for a shape.
        print(f"{key:15s} {tuple(x.shape)}")

In [7]:
class Unet(nn.Module):
    """Four-level U-Net style encoder/decoder with a one-step feature delay.

    The encoder applies two 3x3 convolutions per level (32, 64, 128, 256
    channels) with 2x2 max-pooling between levels. The decoder mirrors it with
    stride-2 transposed convolutions and concatenates the matching encoder
    output (skip connection) before each pair of convolutions.

    Before the final 1x1 convolution, the last decoder feature map is
    concatenated with the (detached) feature map kept from the previous
    ``forward`` call. On the first call, after ``reset_state()``, or whenever
    the cached map is incompatible with the current one, zeros are used instead.

    Args:
        n_classes: Number of output channels produced by the final layer.
    """

    def __init__(self, n_classes):
        super().__init__()
        self._build_encoder()
        self._build_decoder()
        # 32*2 input channels: current decoder features + cached previous features.
        self.final_layer = nn.Conv2d(32*2, n_classes, 1, stride=1, padding=0)
        self.activation = nn.ReLU()
        # Not applied inside forward(); the network returns the raw output of
        # final_layer. Kept so the attribute stays available to callers.
        self.final_activation = nn.Sigmoid()

    def _build_encoder(self):
        """Create the encoder layers (single-channel input, 32 -> 256 channels)."""
        print('Building encoder')
        self.conv_1a = nn.Conv2d(
            in_channels=1,
            out_channels=32,
            kernel_size=(3,3),
            stride=1,
            padding=1
        )
        self.conv_1b = nn.Conv2d(32, 32, 3, stride=1, padding=1)

        self.pool_2 = nn.MaxPool2d(2)
        self.conv_2a = nn.Conv2d(32, 64, 3, stride=1, padding=1)
        self.conv_2b = nn.Conv2d(64, 64, 3, stride=1, padding=1)

        self.pool_3 = nn.MaxPool2d(2)
        self.conv_3a = nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.conv_3b = nn.Conv2d(128, 128, 3, stride=1, padding=1)

        self.pool_4 = nn.MaxPool2d(2)
        self.conv_4a = nn.Conv2d(128, 256, 3, stride=1, padding=1)
        self.conv_4b = nn.Conv2d(256, 256, 3, stride=1, padding=1)

    def _build_decoder(self):
        """Create the decoder layers and initialise the cached previous state."""
        print('Building decoder')
        # Each "a" convolution takes 2x channels: upsampled features + skip connection.
        self.up_3 = nn.ConvTranspose2d(256, 128, 2, stride=2, padding=0)
        self.D_conv_3a = nn.Conv2d(128*2, 128, 3, stride=1, padding=1)
        self.D_conv_3b = nn.Conv2d(128, 128, 3, stride=1, padding=1)

        self.up_2 = nn.ConvTranspose2d(128, 64, 2, stride=2, padding=0)
        self.D_conv_2a = nn.Conv2d(64*2, 64, 3, stride=1, padding=1)
        self.D_conv_2b = nn.Conv2d(64, 64, 3, stride=1, padding=1)

        self.up_1 = nn.ConvTranspose2d(64, 32, 2, stride=2, padding=0)
        self.D_conv_1a = nn.Conv2d(32*2, 32, 3, stride=1, padding=1)
        self.D_conv_1b = nn.Conv2d(32, 32, 3, stride=1, padding=1)

        # Decoder features of the previous forward() call; None means "no history".
        self.prev_state = None

    def reset_state(self):
        """Forget the cached features so the next forward() starts from zeros."""
        self.prev_state = None

    def forward(self, x):
        """Run the network on ``x`` and update the cached previous state.

        Args:
            x: Tensor of shape (N, 1, H, W). H and W must be divisible by 8,
                otherwise the upsampled maps do not match the skip connections
                and the concatenations fail.

        Returns:
            Tensor of shape (N, n_classes, H, W) holding the raw output of the
            final 1x1 convolution (no activation applied).
        """
        cprint('X', x)

        # Encoder
        out = self.activation(self.conv_1a(x))
        cprint("E1a", out)
        E1_out = self.activation(self.conv_1b(out))
        cprint("E1b", E1_out)

        out = self.pool_2(E1_out)
        cprint("pool_2", out)
        out = self.activation(self.conv_2a(out))
        cprint("E2a", out)
        E2_out = self.activation(self.conv_2b(out))
        cprint("E2b", E2_out)

        out = self.pool_3(E2_out)
        cprint("pool_3", out)
        out = self.activation(self.conv_3a(out))
        cprint("E3a", out)
        E3_out = self.activation(self.conv_3b(out))
        cprint("E3b", E3_out)

        out = self.pool_4(E3_out)
        cprint("pool_4", out)
        out = self.activation(self.conv_4a(out))
        cprint("E4a", out)
        out = self.activation(self.conv_4b(out))
        cprint("E4b", out)

        # Decoder
        out = self.up_3(out)
        out = torch.cat([out, E3_out], dim=1)
        cprint("up_3", out)
        out = self.activation(self.D_conv_3a(out))
        cprint("D3a", out)
        out = self.activation(self.D_conv_3b(out))
        cprint("D3b", out)

        out = self.up_2(out)
        out = torch.cat([out, E2_out], dim=1)
        cprint("up_2", out)
        out = self.activation(self.D_conv_2a(out))
        cprint("D2a", out)
        out = self.activation(self.D_conv_2b(out))
        cprint("D2b", out)

        out = self.up_1(out)
        out = torch.cat([out, E1_out], dim=1)
        cprint("up_1", out)
        out = self.activation(self.D_conv_1a(out))
        cprint("D1a", out)
        out = self.activation(self.D_conv_1b(out))
        cprint("D1b", out)

        # Time delay
        # ----------------------------------------
        prev_state = self.prev_state
        # Reuse the cached features only when they can be concatenated with the
        # current ones. A different batch size, spatial size, device or dtype
        # (e.g. a smaller last batch, or the model moved to another device)
        # would otherwise make torch.cat raise.
        if (
            prev_state is None
            or prev_state.shape != out.shape
            or prev_state.device != out.device
            or prev_state.dtype != out.dtype
        ):
            prev_state = torch.zeros_like(out)
        # Detached so gradients never flow back into the previous step's graph.
        self.prev_state = out.detach()
        out = torch.cat([out, prev_state], dim=1)
        # ----------------------------------------

        out = self.final_layer(out)
        cprint("out", out)
        return out

In [14]:
# Global switch read by cprint(): set to True to print tensor shapes during a forward pass.
VERBOSE = False

In [15]:
# Number of target classes; intended as the `n_classes` argument of Unet
# (i.e. the number of channels its final layer outputs).
n_classes = 3

In [16]:
# Build the network with the configured number of classes (instead of a
# hard-coded literal that could drift from `n_classes`) and put it in training mode.
model = Unet(n_classes).train()

Building encoder
Building decoder


In [17]:
# Adam over every model parameter, with a fixed learning rate of 0.005.
optimizer = optim.Adam(model.parameters(), lr=0.005)

In [18]:
# Cross-entropy loss; it is applied to the model's raw (un-activated) output
# and integer class targets in the training loop.
criterion = nn.CrossEntropyLoss()

In [None]:
# Per-batch loss history over the whole run (all epochs); plotted afterwards.
losses = []
n_epochs = 10
for epoch_i in range(n_epochs):
    for batch_i, (X, y) in enumerate(loader):
        # One optimisation step: clear gradients, forward, loss, backward, update.
        optimizer.zero_grad()
        out = model(X)
        # The target is cast to integer class indices and its dim-1 axis is dropped.
        loss = criterion(out, y.long().squeeze(dim=1))
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        # Single-line progress readout; the leading '\r' rewrites the same line.
        # The average covers every batch seen so far, not just this epoch.
        stdout_str = (
            f'\rEpoch {epoch_i+1}/{n_epochs} - '
            f'Batch {batch_i+1}/{len(loader)} '
            f'Avg Loss: {np.mean(losses):0.4f}'
        )
        sys.stdout.write(stdout_str)
        sys.stdout.flush()

Epoch 1/10 - Batch 4/5 Avg Loss: 1.0676

In [None]:
# Plot the loss recorded after every training batch (all epochs concatenated).
# Uses an explicit figure/axes and labels so the figure can be read on its own.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(
    title='Training loss',
    xlabel='Batch (across all epochs)',
    ylabel='Cross-entropy loss',
)
plt.show()