In [None]:
import copy
import os

import pandas as pd
import torch
import torch.nn as nn
from torch import optim
from torch.nn.functional import relu
from torch.utils.data import Dataset
from torchvision import models
from torchvision.io import read_image

Below is a simple U-Net, taken from the PyTorch Hub documentation, for training brain MRI segmentation:

Works Cited: https://pytorch.org/hub/mateuszbuda_brain-segmentation-pytorch_unet/

In [None]:
from collections import OrderedDict

class UNet(nn.Module):
    """Four-level U-Net whose output passes through a sigmoid.

    Adapted from the PyTorch Hub brain-segmentation U-Net.

    Encoder: each level runs a double conv block, and its output is kept for the
    skip connection. A 2x2, stride-2 max-pool then halves the spatial size, and
    the next level doubles the channel count.

    Decoder: each level upsamples with a stride-2 transposed convolution and
    concatenates the matching encoder output along dim 1. A double conv block
    then fuses the result.

    Args:
        in_channels: channel count of the input tensor.
        out_channels: channel count of the predicted map.
        init_features: channel width of the first encoder level; level k uses
            ``init_features * 2 ** (k - 1)`` and the bottleneck ``16x`` that.

    Note: because of the four poolings, the input's height and width must be
    divisible by 16. Otherwise the skip concatenations hit a size mismatch.
    """

    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super().__init__()
        width = init_features

        # Encoder levels 1..4. Sub-modules are registered in the order
        # encoder1, pool1, encoder2, pool2, ... so parameter initialisation
        # order and state_dict keys (encoder{k}.enc{k}conv1.weight, ...) are
        # stable.
        prev_channels = in_channels
        for level in range(1, 5):
            level_channels = width * 2 ** (level - 1)
            setattr(
                self,
                f"encoder{level}",
                UNet._block(prev_channels, level_channels, name=f"enc{level}"),
            )
            setattr(self, f"pool{level}", nn.MaxPool2d(kernel_size=2, stride=2))
            prev_channels = level_channels

        # Bridge between the deepest encoder level and the decoder.
        self.bottleneck = UNet._block(width * 8, width * 16, name="bottleneck")

        # Decoder levels 4..1. Each decoder block receives the upsampled tensor
        # concatenated with the same-level encoder output, so it takes twice the
        # level's channel count as input.
        for level in range(4, 0, -1):
            level_channels = width * 2 ** (level - 1)
            setattr(
                self,
                f"upconv{level}",
                nn.ConvTranspose2d(
                    level_channels * 2, level_channels, kernel_size=2, stride=2
                ),
            )
            setattr(
                self,
                f"decoder{level}",
                UNet._block(level_channels * 2, level_channels, name=f"dec{level}"),
            )

        # 1x1 projection to the requested number of output channels.
        self.conv = nn.Conv2d(
            in_channels=width, out_channels=out_channels, kernel_size=1
        )

    def forward(self, x):
        """Return ``sigmoid(conv(decoder_output))`` for the input batch ``x``."""
        skips = []
        hidden = x
        for level in range(1, 5):
            hidden = getattr(self, f"encoder{level}")(hidden)
            skips.append(hidden)
            hidden = getattr(self, f"pool{level}")(hidden)

        hidden = self.bottleneck(hidden)

        for level in range(4, 0, -1):
            hidden = getattr(self, f"upconv{level}")(hidden)
            hidden = torch.cat((hidden, skips[level - 1]), dim=1)
            hidden = getattr(self, f"decoder{level}")(hidden)

        return torch.sigmoid(self.conv(hidden))

    @staticmethod
    def _block(in_channels, features, name):
        """Build two (3x3 conv -> BatchNorm -> ReLU) stages as one Sequential.

        Layer names are ``name`` followed by conv1/norm1/relu1/conv2/norm2/relu2,
        with no separator, matching the hub checkpoint's keys. The convolutions
        omit bias because each is immediately followed by BatchNorm.
        """
        layers = []
        for stage, stage_in in enumerate((in_channels, features), start=1):
            layers.append(
                (
                    f"{name}conv{stage}",
                    nn.Conv2d(
                        in_channels=stage_in,
                        out_channels=features,
                        kernel_size=3,
                        padding=1,
                        bias=False,
                    ),
                )
            )
            layers.append((f"{name}norm{stage}", nn.BatchNorm2d(num_features=features)))
            layers.append((f"{name}relu{stage}", nn.ReLU(inplace=True)))
        return nn.Sequential(OrderedDict(layers))

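Defining a U-Net layer by layer, as above, is long and repetitive, so the cells below factor the repeated pieces into helper functions. Note that, unlike in Keras, PyTorch layers constructed inside a plain function are not registered with any `nn.Module`: they are freshly initialized on every call and are never updated by an optimizer, so the class above remains the trainable model.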

In [None]:
def conv_block(input, num_filters):
    """Apply two (3x3 conv -> batch norm -> ReLU) stages to ``input``.

    Args:
        input: a 4-D ``(N, C, H, W)`` tensor; ``C`` is read from
            ``input.shape[1]`` to size the first convolution.
        num_filters: channel count produced by both convolutions.

    Returns:
        A ``(N, num_filters, H, W)`` tensor (padding=1 keeps H and W unchanged).

    NOTE(review): every call constructs brand-new, randomly initialised layers
    that are not registered with any ``nn.Module``. The output is therefore not
    trainable by an optimizer. This helper only illustrates the block's shape
    arithmetic; use the ``UNet`` class for real training.
    """
    s = input
    # First stage maps C -> num_filters; second stage keeps num_filters.
    for in_channels in (input.shape[1], num_filters):
        # bias=False: BatchNorm's shift makes a conv bias redundant.
        s = nn.Conv2d(in_channels, num_filters, kernel_size=3, padding=1, bias=False)(s)
        s = nn.BatchNorm2d(num_filters)(s)
        s = nn.ReLU()(s)
    return s

def encoder_block(input, num_filters):
    """Run one U-Net encoder stage on ``input``.

    Returns a pair ``(skip, pooled)``. ``skip`` is the ``conv_block`` output,
    kept for the decoder's skip connection. ``pooled`` is that same tensor
    downsampled by a 2x2, stride-2 max-pool and feeds the next stage.
    """
    skip = conv_block(input, num_filters)
    return skip, torch.nn.functional.max_pool2d(skip, kernel_size=2, stride=2)

def decoder_block(input, num_filters, skip_connections):
    """Run one U-Net decoder stage.

    Steps:
        1. Upsample ``input`` 2x with a stride-2 transposed convolution that
           outputs ``num_filters`` channels.
        2. Concatenate the encoder's ``skip_connections`` along the channel
           dim (dim 1).
        3. Fuse the concatenation with ``conv_block``.

    Args:
        input: ``(N, C, H, W)`` tensor from the previous (deeper) stage.
        num_filters: output channel count of this stage.
        skip_connections: encoder tensor of shape ``(N, C_skip, 2H, 2W)``.

    Returns:
        A ``(N, num_filters, 2H, 2W)`` tensor.

    NOTE(review): like ``conv_block``, the layers built here are fresh on every
    call and are not trained by any optimizer.
    """
    d = nn.ConvTranspose2d(input.shape[1], num_filters, kernel_size=2, stride=2)(input)
    # Concatenate on channels, not the batch axis, so spatial sizes must match.
    d = torch.cat([d, skip_connections], dim=1)
    # Fuse the upsampled + skip tensor (not the original input).
    return conv_block(d, num_filters)


def create_UNet(input_shape=(3, 256, 256), n_classes=1):
    """Build a trainable U-Net with base width 64 (64 -> 128 -> 256 -> 512 -> 1024).

    PyTorch has no Keras-style ``Input`` / functional model graph. Layers created
    inside plain functions are never registered or trained. This factory
    therefore returns the ``nn.Module``-based ``UNet``, configured with the
    channel widths that the helper functions sketch.

    Args:
        input_shape: channels-first shape ``(C, H, W)`` (or just ``(C,)``) of
            one input sample. H and W, if given, must be divisible by 16
            because the network pools four times.
        n_classes: number of output channels. The model applies a sigmoid, so
            each channel is an independent per-pixel probability.

    Returns:
        A freshly initialised ``UNet`` instance.

    Raises:
        ValueError: if a given spatial dimension is not divisible by 16.
    """
    in_channels = input_shape[0]
    spatial = tuple(input_shape[1:])
    if any(size % 16 for size in spatial):
        raise ValueError(
            f"spatial dims {spatial} must be divisible by 16 (four 2x2 poolings); "
            "input_shape is expected channels-first, i.e. (C, H, W)"
        )
    return UNet(in_channels=in_channels, out_channels=n_classes, init_features=64)

In [None]:
# Dataset
class CustomImageDataset(Dataset):
    """Map-style dataset of images listed in an annotations CSV.

    Column 0 of the CSV holds each image's filename, relative to ``img_dir``.
    Column 1 holds the value returned as its label. When ``transform`` or
    ``target_transform`` is set, it is applied to the loaded image or to the
    label, respectively.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # One CSV row per sample.
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Number of rows in the annotations CSV."""
        return self.img_labels.shape[0]

    def __getitem__(self, idx):
        """Load row ``idx``'s image from disk and return ``(image, label)``."""
        image = read_image(os.path.join(self.img_dir, self.img_labels.iloc[idx, 0]))
        label = self.img_labels.iloc[idx, 1]
        image = self.transform(image) if self.transform else image
        label = self.target_transform(label) if self.target_transform else label
        return image, label



# Placeholder locations of the annotation CSVs and image folders.
# TODO: point these at the real MRI data before running this cell.
TRAIN_ANNOTATIONS_FILE = "train_annotations.csv"
TRAIN_IMG_DIR = "train_images"
VAL_ANNOTATIONS_FILE = "val_annotations.csv"
VAL_IMG_DIR = "val_images"

# NOTE(review): not used by the visible cells; kept in case later cells fill them.
images = []
masks = []

# CustomImageDataset requires an annotations file and an image directory.
train_dataset = CustomImageDataset(TRAIN_ANNOTATIONS_FILE, TRAIN_IMG_DIR)
val_dataset = CustomImageDataset(VAL_ANNOTATIONS_FILE, VAL_IMG_DIR)

# Dataloader
# Shuffle the training set so batches differ between epochs; validation is
# evaluated in a fixed order.
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=6)

def train():
    model = create_UNet()
    optimizer = optim.Adam(model.parameters(), lr = 0.001)
    cec = nn.CrossEntropyLoss()
    accuracies = []
    max_accuracy = 0

    device = next(model.parameters()).device

    num_epochs = 10

    for epoch in range(num_epochs):
        # Train Mode
        model.train()

        for i, (images, labels) in enumerate(train_dataloader):
            images = images.to(device)
            labels = labels.to(device)

            #No gradient descent
            optimizer.zero_grad()

            #Calculate Loss
            pred = model(images)
            loss = cec(pred, labels)

            #Backpropagation
            loss.backward()

            #Adjust & optimize model parameters
            optimizer.step()


        model.eval()
        running_loss = 0.0

        with torch.no_grad():
            for inputs, targets in val_dataloader:
                inputs, targets = inputs.cuda(), targets.cuda()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                running_loss += loss.item() * inputs.size(0)
                accuracy = 

        avg_loss = running_loss / len(dataloader.dataset)
        print(f'Validation Loss: {avg_loss:.4f}')
    
        accuracies.append(accuracy)
        
        # Find best model
        if (accuracy > max_accuracy):
            max_accuracy = accuracy
            best_model = copy.deepcopy(model)
            print(f'Epoch: {epoch + 1}, Accuracy: {accuracy}')




#Validation
def val_loop(model, data):
    model.eval() #switch from training to validation mode
    total = 0
    correct = 0

    #Declaring the device
    device = next(model.parameters()).device

    with torch.no_grad(): #No gradient descent needed
        for images, labels in data:
            #Puts each image on CUDA
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()


    return 100 * correct / total