## Dependencies

In [1]:
import os
import csv
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader,random_split
from torchvision.transforms import transforms
from torchvision.datasets import ImageFolder

# Work from the project root (the directory containing `src/`) so the
# `src.*` imports and the relative ./logs, ./models paths resolve.
# Guarded so re-running this cell does not keep climbing directories.
if not os.path.isdir('src'):
    os.chdir('../')

In [2]:
# Show the current working directory (expected to be the project root,
# where the `src/` package imported below lives).
%pwd

'c:\\Projects\\python\\brain_tumor_segmentation'

In [3]:
from src.utils import DiceLoss
from src.model1 import UNet

In [4]:
## Global parameters
# Spatial size (height, width) that every input image is resized to.
H = W = 256

## Model Training

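### Training function

Trains the model with Dice loss and Adam, drives a `ReduceLROnPlateau` scheduler from the validation loss, saves the best-validation checkpoint, and stops early once validation loss stops improving.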

In [None]:
def _ensure_parent_dir(path):
    """Create the directory that will contain ``path`` if it does not exist yet."""
    parent = os.path.dirname(path)
    if parent:  # a bare filename has no directory component to create
        os.makedirs(parent, exist_ok=True)


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over ``loader`` and return the mean per-batch loss.

    Gradients are computed and ``optimizer`` is stepped only when an optimizer
    is given (training pass); otherwise the pass runs in eval mode with
    gradients disabled. Returns 0.0 for an empty loader.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    total_loss, n_batches = 0.0, 0
    with torch.set_grad_enabled(training):
        for images, masks in loader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            loss = criterion(outputs, masks)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            n_batches += 1
    # Averaging (not summing) keeps train/valid losses comparable even though
    # the two loaders contain different numbers of batches.
    return total_loss / max(n_batches, 1)


def train_UNet(model, train_loader, valid_loader, device, num_epochs: int = 500, lr: float = 1e-4,
               patience: int = 20, log_path: str = './logs/train_log.csv',
               criterion=None, save_path: str = './models/model1/best_model.pth',
               scheduler_patience=None) -> float:
    """Train a segmentation model with LR scheduling, early stopping and checkpointing.

    Each epoch makes one gradient-updating pass over ``train_loader`` and one
    no-grad pass over ``valid_loader``, appends the mean per-batch losses to a
    CSV log, and saves ``model.state_dict()`` whenever validation loss improves.

    Args:
        model: ``nn.Module`` to train; it is moved to ``device``.
        train_loader: iterable of ``(images, masks)`` training batches.
        valid_loader: iterable of ``(images, masks)`` validation batches.
        device: ``torch.device`` to run on.
        num_epochs: maximum number of epochs.
        lr: initial Adam learning rate.
        patience: consecutive epochs without validation improvement before
            training stops early.
        log_path: CSV file (overwritten) that receives one row per epoch.
        criterion: loss module called as ``criterion(outputs, masks)``;
            defaults to ``DiceLoss()``.
        save_path: file the best ``state_dict`` is written to.
        scheduler_patience: ``ReduceLROnPlateau`` patience; defaults to
            ``max(1, patience // 2)``.

    Returns:
        The lowest mean validation loss observed (``inf`` if no epoch ran).
    """
    if criterion is None:
        criterion = DiceLoss()
    if scheduler_patience is None:
        # Must be shorter than the early-stopping patience: with equal values,
        # early stopping (after `patience` bad epochs) typically fires before
        # ReduceLROnPlateau (which needs more than `patience` bad epochs) can
        # lower the learning rate, making the scheduler a no-op.
        scheduler_patience = max(1, patience // 2)

    optimizer = optim.Adam(params=model.parameters(), lr=lr)
    # `verbose` is deprecated in recent PyTorch; the current LR is printed per epoch instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                                     mode='min',
                                                     patience=scheduler_patience,
                                                     min_lr=1e-7)
    best_loss = float('inf')
    patience_counter = 0
    model.to(device)

    # Create output directories up front so open()/torch.save() don't fail on a fresh checkout.
    _ensure_parent_dir(log_path)
    _ensure_parent_dir(save_path)
    with open(log_path, mode='w', newline='') as f:
        csv.writer(f).writerow(["Epoch", "Train Loss", "Valid Loss"])

    for epoch in range(num_epochs):
        train_loss = _run_epoch(model, train_loader, criterion, device, optimizer)
        valid_loss = _run_epoch(model, valid_loader, criterion, device)

        scheduler.step(valid_loss)
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, "
              f"Valid Loss: {valid_loss:.4f}, LR: {current_lr:.2e}")

        # Append and close every epoch so the log can be inspected while training runs.
        with open(log_path, mode='a', newline='') as f:
            csv.writer(f).writerow([epoch + 1, train_loss, valid_loss])

        if valid_loss < best_loss:
            best_loss = valid_loss
            torch.save(model.state_dict(), save_path)
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print("Early Stopping triggered!")
            break

    return best_loss


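### Run training

Splits the dataset 60/20/20 into train/validation/test subsets, builds the data loaders, and trains a freshly initialised `UNet`.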

In [None]:
torch.manual_seed(42)

# Output directories for this run. os.makedirs replaces `create_dir`, which was
# never imported and raised NameError on a fresh kernel.
os.makedirs("files", exist_ok=True)
os.makedirs("logs", exist_ok=True)
os.makedirs(os.path.join("models", "model1"), exist_ok=True)

batch_size = 16
lr = 1e-4
num_epochs = 500
# Set the BRAIN_TUMOR_DATA environment variable to point at the dataset on
# your machine instead of editing this machine-specific absolute path.
dataset_path = os.environ.get(
    "BRAIN_TUMOR_DATA",
    "/media/nikhil/Seagate Backup Plus Drive/ML_DATASET/brain_tumor_dataset/data",
)

transform = transforms.Compose([
    transforms.Resize((H, W)),
    transforms.ToTensor()
])

# FIXME(review): ImageFolder yields (image, class_index) pairs, not
# (image, mask) pairs, so the `masks` seen by train_UNet will be integer class
# labels. Segmentation with DiceLoss needs a dataset returning per-pixel masks.
full_dataset = ImageFolder(root=dataset_path, transform=transform)
train_size = int(0.6 * len(full_dataset))
valid_size = int(0.2 * len(full_dataset))
test_size = len(full_dataset) - train_size - valid_size

# Dedicated, seeded generator so the 60/20/20 split is reproducible no matter
# what else has consumed the global RNG in this kernel session.
split_generator = torch.Generator().manual_seed(42)
train_dataset, valid_dataset, test_dataset = random_split(
    full_dataset, [train_size, valid_size, test_size], generator=split_generator
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet()
# The training function defined above is `train_UNet` (`train_model` is undefined).
train_UNet(model, train_loader, valid_loader, device, num_epochs, lr)