In [None]:
import os
import sys
sys.path.append("../")

import numpy as np
import pandas as pd
from PIL import Image

import matplotlib.pyplot as plt

from helper import *
from models import *
from dataset import SolarTrackerDataset

import segmentation_models_pytorch as smp

import torch
from torch.utils.data import DataLoader

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device: ", device)

In [None]:
# --- Data split -------------------------------------------------------------
# Fractions of the tile index passed to the split helper for train / validation.
TRAIN_FRAC = 0.8
VAL_FRAC = 0.1
# Accepted values: "simple" or "random".
SPLIT_STRATEGTY = "random"
# Forwarded as the `augmentation` flag of the training dataset.
DATA_AUGMENTATION = False

# --- Model ------------------------------------------------------------------
# Accepted values: "unet" or "convnext".
MODEL_ARCHITECTURE = "convnext"
# One of: unet, deeplabv3, convnext.
MODEL_NAME = "convnext"
# One of: resnet34, resnet50, convnext_tiny.
MODEL_ENCODER = "convnext_tiny"

# --- Training ---------------------------------------------------------------
EPOCHS = 50
# Checkpoint path, including its directory, derived from the settings above.
CHECKPOINT_NAME = f"../checkpoints/{MODEL_NAME}_data_split_{SPLIT_STRATEGTY}_{MODEL_ENCODER}_V2.pth"

### Data Preparation

In [None]:
# CSV index of tiles; the 'image_path' and 'label_path' columns are read from it below.
metadata = "../metadata/tiles_index.csv"
df = pd.read_csv(filepath_or_buffer=metadata)

### Data split strategy

In [None]:
# Parallel lists of image / label paths taken from the tile index.
image_paths = df["image_path"].tolist()
label_paths = df["label_path"].tolist()

split = DataSplitStrategy()

# Reject unknown strategies before picking the splitter.
if SPLIT_STRATEGTY not in ("simple", "random"):
    raise ValueError("Invalid split strategy. Choose either 'simple' or 'random'.")

# Only the selected method is looked up; both take the same keyword arguments.
split_fn = split.simple_split if SPLIT_STRATEGTY == "simple" else split.random_split

# The third element returned by the splitter is not used in this notebook.
train, val, _ = split_fn(
    image_paths=image_paths,
    label_paths=label_paths,
    train_frac=TRAIN_FRAC,
    val_frac=VAL_FRAC,
)

# Each subset unpacks into (image paths, label paths).
train_image_path, train_label_path = train
val_image_path, val_label_path = val

### Create dataloaders for training

In [None]:
# Augmentation is configurable for training and always off for validation.
train_dataset = SolarTrackerDataset(
    img_paths=train_image_path,
    mask_paths=train_label_path,
    augmentation=DATA_AUGMENTATION,
)
val_dataset = SolarTrackerDataset(
    img_paths=val_image_path,
    mask_paths=val_label_path,
    augmentation=False,
)

# Options shared by both loaders; only the shuffling differs.
loader_options = dict(batch_size=16, num_workers=4, pin_memory=True)

train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_dataloader = DataLoader(val_dataset, shuffle=False, **loader_options)


In [None]:
# Build the network selected in the configuration cell.
if MODEL_ARCHITECTURE == "unet":
    model = get_segmentation_model(model_name=MODEL_NAME, model_encoder=MODEL_ENCODER)
else:
    # Take the backbone from MODEL_ENCODER instead of a hard-coded name, so the
    # model that is trained always matches the encoder embedded in CHECKPOINT_NAME.
    model = ConvNeXtModel(backbone_name=MODEL_ENCODER, pretrained=True, num_classes=1)

# Move the parameters to the device chosen at the top of the notebook.
model = model.to(device)

In [None]:
# The two criteria that are summed into the training objective.
dice_loss = smp.losses.DiceLoss(mode="binary")
bce_loss = torch.nn.BCEWithLogitsLoss()


def loss_fn(pred, true):
    """Return the Dice loss plus the BCE-with-logits loss for one batch.

    Args:
        pred: Model output, passed unchanged to both criteria.
        true: Target masks, passed unchanged to both criteria.

    Returns:
        The sum of the two loss values.
    """
    dice_term = dice_loss(pred, true)
    bce_term = bce_loss(pred, true)
    return dice_term + bce_term


# Adam over all model parameters with a fixed learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

### Model Training

In [None]:
def train_one_epoch(model, loader):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Uses the notebook-level ``optimizer``, ``loss_fn`` and ``device``.

    Args:
        model: Network to optimise; it is switched to training mode here.
        loader: Iterable of ``(idx, images, masks)`` batches that supports ``len()``.

    Returns:
        The sum of the per-batch loss values divided by ``len(loader)``.

    Raises:
        ZeroDivisionError: If ``loader`` contains no batches.
    """
    model.train()
    running_loss = 0.0

    # The first element of each batch is not needed for training.
    for _, images, masks in loader:
        optimizer.zero_grad()

        # Move the last image axis to position 1 and give the masks a new axis 1
        # (presumably NHWC -> NCHW and a channel axis; confirm against the dataset).
        images = images.permute(0, 3, 1, 2).to(device, non_blocking=True)
        masks = masks.unsqueeze(1).to(device, non_blocking=True)

        # The model already lives on `device`, so its output needs no extra transfer.
        outputs = model(images)
        loss = loss_fn(outputs, masks)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(loader)

def validate(model, loader):
    """Evaluate ``model`` on ``loader`` without gradients and return the mean batch loss.

    Uses the notebook-level ``loss_fn`` and ``device``.

    Args:
        model: Network to evaluate; it is switched to evaluation mode here.
        loader: Iterable of ``(idx, images, masks)`` batches that supports ``len()``.

    Returns:
        The sum of the per-batch loss values divided by ``len(loader)``.
    """
    model.eval()
    batch_losses = []

    with torch.no_grad():
        for _, batch_images, batch_masks in loader:
            # Same layout change as in training: last image axis to position 1,
            # new axis 1 on the masks, then transfer both to the device.
            batch_images = batch_images.permute(0, 3, 1, 2)
            batch_masks = batch_masks.unsqueeze(1)
            batch_images = batch_images.to(device, non_blocking=True)
            batch_masks = batch_masks.to(device, non_blocking=True)

            predictions = model(batch_images)
            batch_losses.append(loss_fn(predictions, batch_masks).item())

    return sum(batch_losses) / len(loader)


In [None]:
best_val_loss = float('inf')
TRAIN_LOSS_PER_EPOCH = []
VAL_LOSS_PER_EPOCH = []

# CHECKPOINT_NAME already carries its directory, so it is used as the save path
# as-is. Create that directory up front so the first save cannot fail after a
# whole epoch of training.
checkpoint_dir = os.path.dirname(CHECKPOINT_NAME)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(EPOCHS):
    train_loss = train_one_epoch(model, train_dataloader)
    val_loss = validate(model, val_dataloader)

    print(f"Epoch {epoch+1}/{EPOCHS} - Train Loss: {train_loss:.4f} - Val Loss: {val_loss:.4f}")

    # Record the history first so the curves stay complete even if saving fails.
    TRAIN_LOSS_PER_EPOCH.append(train_loss)
    VAL_LOSS_PER_EPOCH.append(val_loss)

    # Keep only the weights with the lowest validation loss seen so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), CHECKPOINT_NAME)

### Plotting training and validation loss

In [None]:
# One axes with both loss curves; the x position is the index of each recorded epoch.
fig, ax = plt.subplots()
ax.plot(TRAIN_LOSS_PER_EPOCH, label="Train Loss")
ax.plot(VAL_LOSS_PER_EPOCH, label="Val Loss")
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ax.set_title("Training and Validation Loss per Epoch")
ax.legend()
plt.show()