# Segmentation Model

In [1]:
# Prerequisites
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from segmentation_dataset import SegmentationDataset
import segmentation_models_pytorch as smp 



  from .autonotebook import tqdm as notebook_tqdm


### Check if GPU is available

In [2]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using ", DEVICE)

Using  cuda


### Set Hyperparameters

In [3]:
NR_EPOCHS = 50
BATCH_SIZE = 4

### Set 'train' and 'val' Datasets & Dataloaders

In [4]:
train_dataset = SegmentationDataset(path_name='train')
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataset = SegmentationDataset(path_name='val')
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)

### Define Model

In [5]:
model = smp.FPN( #  Feature Pyramid Network
    encoder_name = "se_resnext50_32x4d", # Encoder - ResNeXt-50
    encoder_weights="imagenet", # Pretrained weights for encoder
    classes=6, # Number of output classes
    activation="sigmoid" #
) 

model.to(DEVICE)   # Move model to GPU if available

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001) # Adam optimizer

criterion = nn.CrossEntropyLoss() # Cross-entropy loss
# criterion = smp.losses.DiceLoss(mode='multiclass') # Dice loss

### Training Loop

In [6]:
train_losses = []
val_losses = []

for epoch in range(NR_EPOCHS):

    # Training Phase
    model.train()
    running_train_loss = 0
    running_val_loss = 0

    for i, data in enumerate(train_dataloader):
        image_i, mask_i = data
        image = image_i.to(DEVICE)
        mask = mask_i.to(DEVICE)

        # Reset gradients
        optimizer.zero_grad()

        # Forward Pass
        output = model(image.float())

        # Calculate losses
        train_loss = criterion(output.float(), mask.long())

        # Back propagation
        train_loss.backward()
        optimizer.step()  # Update weights

        running_train_loss += train_loss.item()

    train_losses.append(running_train_loss)

    # Validation Phase
    model.eval()
    for i, data in enumerate(val_dataloader):
        image_i, mask_i = data
        image = image_i.to(DEVICE)
        mask = mask_i.to(DEVICE)

        # Forward Pass
        output = model(image.float())

        # Calculate losses
        val_loss = criterion(output.float(), mask.long())
        running_val_loss += val_loss.item()

    val_losses.append(running_val_loss)

    print(f"Epoch: {epoch}: Train Loss: {np.median(running_train_loss)}, Val Loss: {np.median(running_val_loss)}")


Epoch: 0: Train Loss: 213.1577844619751, Val Loss: 10.892298579216003
Epoch: 1: Train Loss: 192.37237656116486, Val Loss: 10.668810367584229


KeyboardInterrupt: 

### Visualize

In [None]:
sns.lineplot(x = range(len(train_losses)), y= train_losses).set(title='Train Loss')
plt.show()
sns.lineplot(x = range(len(train_losses)), y= val_losses).set(title='Validation Loss')
plt.show()

### Save Model

In [None]:
torch.save(model.state_dict(), f'saved_models/FPN_epochs_{NR_EPOCHS}_crossentropy_state_dict.pth')