In [None]:
# Import libraries

## General
import os
import random
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

## PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as pth_transforms

## Transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [3]:
# Select device which is not used.
# GPU_INDEX pins a specific card on a shared multi-GPU machine. It is only applied
# when CUDA is available and that card exists (otherwise cuda:0 is used); without
# CUDA we fall back to the CPU instead of crashing in set_device().
GPU_INDEX = 3

if torch.cuda.is_available():
    gpu_index = GPU_INDEX if GPU_INDEX < torch.cuda.device_count() else 0
    torch.cuda.set_device(gpu_index)
    torch.cuda.empty_cache()
    device = torch.device('cuda', gpu_index)
else:
    device = torch.device('cpu')
device

In [4]:
# Train, Valid, Test Data Classes

class PigDataset(Dataset):
    """Image/mask pairs for training the segmentation network.

    Every ``.npy`` file in ``mask_dir`` is a mask; its image is the ``.png`` with
    the same stem in ``root_dir``. A random subset of at most ``max_samples``
    masks is kept.

    Args:
        root_dir: Directory containing the ``.png`` images.
        mask_dir: Directory containing the ``.npy`` masks.
        transform: Optional Albumentations-style transform, called as
            ``transform(image=..., mask=...)`` and returning a dict with
            ``'image'`` and ``'mask'``.
        max_samples: Maximum number of masks to keep; ``None`` keeps all.
            Defaults to 5000.
        seed: Seed for the subset shuffle. ``None`` shuffles with the global
            ``random`` module, so a preceding ``random.seed(...)`` still applies.
        image_size: ``(width, height)`` passed to ``PIL.Image.resize``.
    """

    def __init__(self, root_dir, mask_dir, transform=None, max_samples=5000, seed=None, image_size=(640, 352)):
        self.root_dir = root_dir
        self.mask_dir = mask_dir
        self.image_filenames = sorted(f for f in os.listdir(self.root_dir) if f.endswith('.png'))
        # Sort before shuffling: os.listdir order is arbitrary, so without this a
        # fixed seed would still not reproduce the same subset.
        mask_filenames = sorted(f for f in os.listdir(self.mask_dir) if f.endswith('.npy'))
        rng = random.Random(seed) if seed is not None else random
        rng.shuffle(mask_filenames)
        self.mask_filenames = mask_filenames if max_samples is None else mask_filenames[:max_samples]
        self.transform = transform
        self.image_size = image_size

    def __len__(self):
        return len(self.mask_filenames)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for the idx-th selected mask.

        Without a transform the image is a resized RGB PIL image and the mask the
        raw array from ``np.load``; with one, both come from the transform output.
        """
        mask_file = self.mask_filenames[idx]
        mask_name = os.path.join(self.mask_dir, mask_file)
        # Swap only the extension (str.replace would rewrite every '.npy' in the name).
        img_name = os.path.join(self.root_dir, os.path.splitext(mask_file)[0] + '.png')

        with Image.open(img_name) as img:
            image = img.convert('RGB').resize(self.image_size)
        mask = np.load(mask_name)

        if self.transform:
            augmented = self.transform(image=np.array(image), mask=np.array(mask))
            image = augmented['image']
            mask = augmented['mask']

        return image, mask

class PigDatasetGenerateMasks(Dataset):
    """Images to run mask inference on; yields ``(image, filename)``.

    Uses every ``.png`` in ``root_dir`` whose name does not start with ``'b'``.

    Args:
        root_dir: Directory containing the ``.png`` images.
        transform: Optional Albumentations-style transform, called as
            ``transform(image=...)`` and returning a dict with ``'image'``.
        image_size: ``(width, height)`` passed to ``PIL.Image.resize``; defaults
            to (768, 768). The UNet in this notebook halves H and W four times,
            so sizes divisible by 16 are the safest choice.
    """

    def __init__(self, root_dir, transform=None, image_size=(768, 768)):
        self.root_dir = root_dir
        # NOTE(review): names starting with 'b' are skipped; the reason is not
        # visible here — confirm what those files are.
        self.image_filenames = sorted(
            f for f in os.listdir(self.root_dir)
            if f.endswith('.png') and not f.startswith('b')
        )
        self.transform = transform
        self.image_size = image_size

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        filename = self.image_filenames[idx]
        with Image.open(os.path.join(self.root_dir, filename)) as img:
            image = img.convert('RGB').resize(self.image_size)

        if self.transform:
            image = self.transform(image=np.array(image))['image']

        return image, filename
        
# Seed Python, NumPy and PyTorch RNGs so the random mask subsets drawn below
# (PigDataset shuffles with the global `random` module) and later weight
# initialisation are reproducible across Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

IMAGE_DIR = '../data/pigsRawFull'
MASK_DIR = '../data/dinoMasks'

# Define Albumentations transforms.
# Training: random shift/scale/rotate and flips, then float conversion and a
# tensor conversion (ToTensorV2).
transform = A.Compose([A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=15, p=0.5),
                       A.HorizontalFlip(p=0.5),
                       A.VerticalFlip(p=0.5),
                       A.ToFloat(),
                       ToTensorV2()])

# Validation / inference: no augmentation, same conversions.
transformValid = A.Compose([A.ToFloat(),
                            ToTensorV2()])

# Datasets
dataset = PigDataset(root_dir=IMAGE_DIR, mask_dir=MASK_DIR, transform=transform)
datasetValid = PigDataset(root_dir=IMAGE_DIR, mask_dir=MASK_DIR, transform=transformValid)

# Each PigDataset draws its own random subset from the same mask directory, so the
# two subsets above overlap. Re-draw validation from masks NOT used for training
# (keeping at most the size PigDataset selected) so it is a real held-out set.
trainMaskFiles = set(dataset.mask_filenames)
heldOutMaskFiles = sorted(f for f in os.listdir(MASK_DIR) if f.endswith('.npy') and f not in trainMaskFiles)
random.shuffle(heldOutMaskFiles)
datasetValid.mask_filenames = heldOutMaskFiles[:len(datasetValid.mask_filenames)]
if not datasetValid.mask_filenames:
    print('Warning: no held-out masks left for validation; every mask is in the training subset.')

datasetGenerate = PigDatasetGenerateMasks(root_dir=IMAGE_DIR, transform=transformValid)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=8)

In [5]:
# UNet Architecture

class UNet(nn.Module):
    """Five-level U-Net that maps an image to one channel of segmentation logits.

    Encoder levels have 32/64/128/256/512 channels. Downsampling uses stride-2
    3x3 convolutions (output size ceil(H/2)). Upsampling uses stride-2
    transposed convolutions (output size 2*H). Skip connections concatenate the
    dropout-applied encoder features with the upsampled decoder features.

    For an ``(N, input_channels, H, W)`` input the output is ``(N, 1, H, W)`` raw
    logits (no sigmoid is applied). Any H/W works: when a size is not divisible
    by 16, the upsampled tensor is cropped to match its skip connection.

    The model is built on the CPU; move it with ``.to(device)``.

    Args:
        input_channels: Number of channels in the input image.
    """

    def __init__(self, input_channels=1):
        super().__init__()
        # NOTE: attribute names and creation order are kept fixed so state_dict
        # keys (and existing checkpoints) stay compatible.
        self.convP1 = 32
        self.convP2 = 64
        self.convP3 = 128
        self.convP4 = 256
        self.convP5 = 512

        # Encoder: two 3x3 conv + ReLU per level.
        self.conv1 = self.conv_block(input_channels, self.convP1)
        self.conv2 = self.conv_block(self.convP1, self.convP2)
        self.conv3 = self.conv_block(self.convP2, self.convP3)
        self.conv4 = self.conv_block(self.convP3, self.convP4)
        self.conv5 = self.conv_block(self.convP4, self.convP5)

        # Learned downsampling between encoder levels.
        self.convStride1 = nn.Conv2d(self.convP1, self.convP1, kernel_size=3, padding=1, stride=(2, 2))
        self.convStride2 = nn.Conv2d(self.convP2, self.convP2, kernel_size=3, padding=1, stride=(2, 2))
        self.convStride3 = nn.Conv2d(self.convP3, self.convP3, kernel_size=3, padding=1, stride=(2, 2))
        self.convStride4 = nn.Conv2d(self.convP4, self.convP4, kernel_size=3, padding=1, stride=(2, 2))

        # Decoder: halve channels after each upsample, before the skip concat.
        self.convsingle1 = self.conv_block_single(self.convP5, self.convP4)
        self.convsingle2 = self.conv_block_single(self.convP4, self.convP3)
        self.convsingle3 = self.conv_block_single(self.convP3, self.convP2)
        self.convsingle4 = self.conv_block_single(self.convP2, self.convP1)

        # Decoder: two 3x3 conv + ReLU after each skip concatenation.
        self.upconv1 = self.conv_block2(self.convP5, self.convP4)
        self.upconv2 = self.conv_block2(self.convP4, self.convP3)
        self.upconv3 = self.conv_block2(self.convP3, self.convP2)
        self.upconv4 = self.conv_block2(self.convP2, self.convP1)

        # Not used by forward(); kept so the attribute still exists. It has no
        # parameters, so it does not appear in state_dict.
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.upsample2 = nn.ConvTranspose2d(self.convP5, self.convP5, kernel_size=2, stride=2)
        self.upsample3 = nn.ConvTranspose2d(self.convP4, self.convP4, kernel_size=2, stride=2)
        self.upsample4 = nn.ConvTranspose2d(self.convP3, self.convP3, kernel_size=2, stride=2)
        self.upsample5 = nn.ConvTranspose2d(self.convP2, self.convP2, kernel_size=2, stride=2)

        # Heavier dropout on the two deepest levels, light dropout on levels 1-3.
        self.dropout = nn.Dropout2d(p=0.2)
        self.dropoutEncoding = nn.Dropout2d(p=0.05)

        # NOTE(review): there is no activation between final_conv1 and
        # final_conv2, so together they are a single linear map. Adding one would
        # change what existing checkpoints compute, so it is left as is.
        self.final_conv1 = nn.Conv2d(self.convP1, 2, kernel_size=3, padding='same')
        self.final_conv2 = nn.Conv2d(2, 1, kernel_size=1, padding='valid')

        self.initialize_weights()  # He-normal init for all (transposed) convs

    def initialize_weights(self):
        """Kaiming-normal init for every Conv2d / ConvTranspose2d weight; zero biases."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def conv_block(self, in_channels, out_channels):
        """Two size-preserving 3x3 conv + ReLU layers."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def conv_block2(self, in_channels, out_channels):
        """Decoder double-conv block; same layers as ``conv_block``."""
        return self.conv_block(in_channels, out_channels)

    def conv_block_single(self, in_channels, out_channels):
        """One size-preserving 3x3 conv + ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding='same'),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _crop_to(x, ref):
        """Crop ``x``'s last two dims to ``ref``'s spatial size.

        Stride-2 downsampling rounds odd sizes up, so 2x upsampling can overshoot
        the skip tensor by one pixel; for sizes divisible by 16 this is a no-op.
        """
        return x[..., :ref.shape[-2], :ref.shape[-1]]

    def forward(self, x):
        # Encoder
        conv1 = self.conv1(x)
        drop1 = self.dropoutEncoding(conv1)
        sdrop1 = self.convStride1(drop1)

        conv2 = self.conv2(sdrop1)
        drop2 = self.dropoutEncoding(conv2)
        sdrop2 = self.convStride2(drop2)

        conv3 = self.conv3(sdrop2)
        drop3 = self.dropoutEncoding(conv3)
        sdrop3 = self.convStride3(drop3)

        conv4 = self.conv4(sdrop3)
        drop4 = self.dropout(conv4)
        sdrop4 = self.convStride4(drop4)

        conv5 = self.conv5(sdrop4)
        drop5 = self.dropout(conv5)

        # Decoder
        up6 = self.convsingle1(self._crop_to(self.upsample2(drop5), drop4))
        conv6 = self.upconv1(torch.cat([drop4, up6], dim=1))

        up7 = self.convsingle2(self._crop_to(self.upsample3(conv6), drop3))
        conv7 = self.upconv2(torch.cat([drop3, up7], dim=1))

        up8 = self.convsingle3(self._crop_to(self.upsample4(conv7), drop2))
        conv8 = self.upconv3(torch.cat([drop2, up8], dim=1))

        up9 = self.convsingle4(self._crop_to(self.upsample5(conv8), drop1))
        conv9 = self.upconv4(torch.cat([drop1, up9], dim=1))

        return self.final_conv2(self.final_conv1(conv9))

In [6]:
# Single image inference
def getPrediction(imageIndex, output_size=(360, 640), net=None, dataset=None):
    """Predict the mask logits for one image of the generation dataset.

    Args:
        imageIndex: Index into ``dataset``.
        output_size: ``(height, width)`` that the logit map is resized to with
            nearest-neighbour interpolation. Defaults to (360, 640).
        net: Model to run; defaults to the notebook-global ``model``.
        dataset: Indexable returning ``(image_tensor, filename)``; defaults to
            the notebook-global ``datasetGenerate``.

    Returns:
        ``(predicted_masks, predicted_masks_numpy, filename)``. The first two hold
        the resized raw logits (no sigmoid applied), as a CPU tensor and as a
        NumPy array of shape ``output_size``.
    """
    if net is None:
        net = model
    if dataset is None:
        dataset = datasetGenerate

    image, filename = dataset[imageIndex]
    # Run on whatever device the model's weights are on.
    net_device = next(net.parameters()).device

    # Dropout is stochastic in train mode; predict in eval mode and restore the
    # caller's mode afterwards.
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            logits = net(image.unsqueeze(0).to(net_device)).cpu()
            predicted_masks = F.interpolate(logits, size=output_size, mode='nearest').squeeze()
    finally:
        net.train(was_training)

    return predicted_masks, predicted_masks.numpy(), filename

In [None]:
# Train Network

# Define your UNet model
model = UNet(input_channels=3)
model = model.to(device)

# BCEWithLogitsLoss applies the sigmoid itself, so the UNet outputs raw logits.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adamax(model.parameters(), lr=5e-5, weight_decay=1e-4)

# Training loop
num_epochs = 15
LOG_EVERY = 75  # iterations between progress prints
iteration = 0
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0

    for images, masks in dataloader:
        images, masks = images.to(device).float(), masks.to(device).float()

        optimizer.zero_grad()
        outputs = model(images)

        # Model output is (B, 1, H, W); drop the channel axis to match the masks
        # (presumably (B, H, W) — TODO confirm the .npy mask shape).
        loss = criterion(outputs.squeeze(dim=1), masks)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1
        iteration += 1
        if iteration % LOG_EVERY == 0:
            # Mean over this epoch's batches so far: comparable across log points,
            # unlike the raw sum, which grows with the batch count.
            print(iteration, str(round(running_loss / num_batches, 4)))

    epoch_loss = running_loss / max(num_batches, 1)
    print('Epoch ' + str(epoch + 1) + ' - Mean loss: ' + str(round(epoch_loss, 4)))

# Leave the model in eval mode so later inference cells run without dropout.
model.eval()
print('Training finished!')

In [20]:
# Persist only the learned weights. To reuse them, rebuild UNet(input_channels=3)
# and call load_state_dict() on it.
# NOTE(review): the "30" in the filename does not match num_epochs = 15 in the
# training cell — confirm which run this checkpoint belongs to.
checkpointPath = 'UNetPigsCheckpoint30.pth'
weights = model.state_dict()
torch.save(weights, checkpointPath)