In [None]:
import json

# Path to the JSON Lines annotation file: one JSON object per line.
train_jsonl_path = '/srv/data/lt2326-h25/lt2326-h25/a1/train.jsonl'

# Lookup table: image_id -> the full annotation record of that image.
annotations_data = {}

# Decode explicitly as UTF-8 so parsing does not depend on the platform's
# default encoding (the records may contain non-ASCII text).
with open(train_jsonl_path, 'r', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            continue

        image_annotations = json.loads(line)

        image_id = image_annotations['image_id']

        annotations_data[image_id] = image_annotations
    

In [None]:

# Master list describing the official dataset splits.
info_json_path = '/srv/data/lt2326-h25/lt2326-h25/a1/info.json'

print(f"Loading master list from: {info_json_path}")

# Decode explicitly as UTF-8 so parsing does not depend on the platform default.
with open(info_json_path, 'r', encoding='utf-8') as f:
    info_data = json.load(f)

# Entries of the official training split; each entry carries a 'file_name'.
official_training_list = info_data['train']

# A set gives O(1) membership tests and removes duplicate file names.
official_training_filenames = {img['file_name'] for img in official_training_list}

print(f"Successfully found and processed {len(official_training_filenames)} official training file entries.")

# Show one example without crashing when the split is empty.
example_filename = next(iter(official_training_filenames), None)
if example_filename is not None:
    print(f"Example filename: {example_filename}")
else:
    print("No training file entries were found in the master list.")



In [None]:
import os
import random



images_dir_path = '/srv/data/lt2326-h25/lt2326-h25/a1/images'

# Seed for the train/val/test shuffle so the split is identical on every run.
SPLIT_SEED = 42

# Only .jpg files that are actually present on disk.
available_image_filenames = {f for f in os.listdir(images_dir_path) if f.endswith('.jpg')}

# Keep images that are both in the official training list and available on disk.
usable_filenames_set = official_training_filenames & available_image_filenames

# Set iteration order for strings changes between interpreter runs (hash
# randomisation), so sort first and then shuffle with a dedicated seeded RNG.
# Without this the three splits differ after every kernel restart.
usable_filenames = sorted(usable_filenames_set)

print(f"\nThere are a total of {len(usable_filenames)} Chinese character images to split.")
random.Random(SPLIT_SEED).shuffle(usable_filenames)

train_ratio = 0.8
val_ratio = 0.1

# Slice boundaries: [0, train_end) -> train, [train_end, val_end) -> val, rest -> test.
total_count = len(usable_filenames)
train_end_index = int(total_count * train_ratio)
val_end_index = int(total_count * (train_ratio + val_ratio))

train_files = usable_filenames[:train_end_index]
val_files = usable_filenames[train_end_index:val_end_index]
test_files = usable_filenames[val_end_index:]

print(f"Training set:   {len(train_files)} files")
print(f"Validation set: {len(val_files)} files")
print(f"Test set:       {len(test_files)} files")

In [None]:
def structure_dataset(file_list, annotations_lookup, images_base_path):
    """Pair every image file with its annotation record.

    Args:
        file_list: iterable of image file names (e.g. ``'0001.jpg'``).
        annotations_lookup: mapping from image id (file name without its
            extension) to the annotation record of that image.
        images_base_path: directory that contains the image files.

    Returns:
        A list of ``{'image_path': ..., 'annotations': ...}`` dicts, in the
        order of ``file_list``. Files without an entry in
        ``annotations_lookup`` are left out.
    """
    # (file name, image id) pairs; the id is the name minus its extension.
    named_ids = ((name, os.path.splitext(name)[0]) for name in file_list)
    return [
        {
            'image_path': os.path.join(images_base_path, name),
            'annotations': annotations_lookup[stem],
        }
        for name, stem in named_ids
        if stem in annotations_lookup
    ]


# Build the structured train / validation / test lists with one helper call
# per split (same helper, same lookup table, same image directory).
train_dataset, val_dataset, test_dataset = (
    structure_dataset(split_files, annotations_data, images_dir_path)
    for split_files in (train_files, val_files, test_files)
)

# Report how many images of each split have an annotation record.
print(f" The 'train_dataset' was created with {len(train_dataset)} pics.")
print(f" The 'val_dataset'was created with {len(val_dataset)} pics.")
print(f" The 'test_dataset' was created with {len(test_dataset)} pics.")


In [None]:
import torch
from torch.utils.data import Dataset
import cv2
import numpy as np
import torchvision.transforms as T
import math
from tqdm import tqdm

class ChineseCharacterDataset(Dataset):
    """Sliding-window segmentation dataset.

    Every item is one ``window_size`` x ``window_size`` window cut from an
    image, together with a binary mask of the same size in which pixels
    inside polygons flagged ``is_chinese`` are 1 and all other pixels are 0.
    Windows that reach past the image border are zero-padded on the right and
    bottom.

    Args:
        dataset_list: list of dicts with an ``'image_path'`` entry and an
            ``'annotations'`` dict that provides ``'height'`` and ``'width'``.
        annotations_data: mapping from image id (file name without ``.jpg``)
            to a record whose ``'annotations'`` entry is a list of lists of
            per-character dicts (``'is_chinese'``, ``'polygon'``).
        window_size: side length of the square window in pixels.
        stride: step between neighbouring windows in pixels.
        transform: optional callable applied to the image tensor ONLY. The
            mask is never transformed, so geometric transforms (flips,
            rotations, crops) would misalign image and mask; pass only
            photometric transforms or normalisation here.
    """

    def __init__(self, dataset_list, annotations_data, window_size=256, stride=128, transform=None):
        self.dataset_list = dataset_list
        self.annotations_data = annotations_data
        self.window_size = window_size
        self.stride = stride
        self.transform = transform

        # image_path -> (RGB image, float32 mask). Entries are never evicted,
        # so memory grows with the number of distinct images accessed.
        self.cache = {}

        # One entry per window: source image plus top-left corner.
        self.windows = []
        print("Pre-calculating sliding windows")
        for item in tqdm(self.dataset_list):
            h = item['annotations']['height']
            w = item['annotations']['width']

            # Enough windows to cover the full extent; an image that is not
            # larger than the window gets exactly one (padded) window.
            num_x = math.ceil((w - self.window_size) / self.stride) + 1 if w > self.window_size else 1
            num_y = math.ceil((h - self.window_size) / self.stride) + 1 if h > self.window_size else 1

            for i in range(num_y):
                for j in range(num_x):
                    self.windows.append({
                        'image_path': item['image_path'],
                        'x': j * self.stride,
                        'y': i * self.stride,
                    })

        print(f"--- Created {len(self.windows)} total windows from {len(self.dataset_list)} images. ---")

    def __len__(self):
        """Return the total number of windows over all images."""
        return len(self.windows)

    def _load_image_and_mask(self, image_path):
        """Read an image as RGB and rasterise its Chinese-character polygons.

        Returns:
            ``(image, mask)`` where ``image`` is the RGB array returned by
            OpenCV and ``mask`` is a float32 array with the same height and
            width, 1 inside annotated Chinese characters and 0 elsewhere.

        Raises:
            OSError: if OpenCV cannot load ``image_path``.
        """
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than inside cvtColor.
            raise OSError(f"cv2.imread failed to load image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.float32)
        image_id = image_path.split('/')[-1].replace('.jpg', '')
        if image_id in self.annotations_data:
            for char_list in self.annotations_data[image_id]['annotations']:
                if isinstance(char_list, list) and len(char_list) > 0:
                    for char_annotation in char_list:
                        if char_annotation.get('is_chinese', False):
                            polygon = np.array(char_annotation['polygon'], dtype=np.int32).reshape(-1, 2)
                            cv2.fillPoly(mask, [polygon], 1)
        return image, mask

    def __getitem__(self, idx):
        """Return ``(image_tensor, mask_tensor)`` for window ``idx``.

        ``image_tensor`` is the ``T.ToTensor()`` conversion of the RGB patch
        (followed by ``transform`` when one was given); ``mask_tensor`` is a
        float tensor with a leading channel dimension of size 1.
        """
        window_info = self.windows[idx]
        image_path = window_info['image_path']

        if image_path not in self.cache:
            self.cache[image_path] = self._load_image_and_mask(image_path)
        image, mask = self.cache[image_path]

        x, y = window_info['x'], window_info['y']
        h, w, _ = image.shape

        # Clip the window to the image bounds; the remainder is padded below.
        x_end = min(x + self.window_size, w)
        y_end = min(y + self.window_size, h)

        image_patch = image[y:y_end, x:x_end]
        mask_patch = mask[y:y_end, x:x_end]

        pad_x = self.window_size - image_patch.shape[1]
        pad_y = self.window_size - image_patch.shape[0]
        if pad_x > 0 or pad_y > 0:
            # Zero-pad right/bottom so every sample has the same spatial size.
            image_patch = np.pad(image_patch, ((0, pad_y), (0, pad_x), (0, 0)), mode='constant')
            mask_patch = np.pad(mask_patch, ((0, pad_y), (0, pad_x)), mode='constant')

        image_tensor = T.ToTensor()(image_patch)
        # Copy so the tensor does not share memory with the cached mask.
        mask_tensor = torch.from_numpy(mask_patch.copy()).unsqueeze(0).float()

        if self.transform:
            image_tensor = self.transform(image_tensor)

        return image_tensor, mask_tensor

In [None]:
def calculate_mean_std(dataset_list, image_dir, resize_shape=(256, 256)):
    """Compute per-channel mean and standard deviation over a dataset split.

    The statistics are taken over every window that ``ChineseCharacterDataset``
    produces for ``dataset_list`` (overlapping windows and zero padding
    included). Pass only the training split so that no validation/test
    information leaks into the normalisation constants.

    Args:
        dataset_list: list of ``{'image_path': ..., 'annotations': ...}``
            dicts, as produced by ``structure_dataset``.
        image_dir: unused; kept for backward compatibility with callers.
        resize_shape: unused; kept for backward compatibility with callers.

    Returns:
        ``(mean, std)`` as two lists of three floats (one per channel).

    Raises:
        ValueError: if the dataset yields no pixels.
    """
    # Imported locally so the function works even when this cell is run
    # before the cell that imports DataLoader at notebook level.
    from torch.utils.data import DataLoader

    # The dataset class requires an id -> record lookup. Build it from the
    # split itself, using the same id derivation as the dataset class.
    annotations_lookup = {
        item['image_path'].split('/')[-1].replace('.jpg', ''): item['annotations']
        for item in dataset_list
    }

    temp_dataset = ChineseCharacterDataset(dataset_list, annotations_lookup)
    loader = DataLoader(temp_dataset, batch_size=64, shuffle=False)

    # float64 accumulators: summing many float32 pixel values loses precision.
    sum_of_pixels = torch.zeros(3, dtype=torch.float64)
    sum_of_squares = torch.zeros(3, dtype=torch.float64)
    num_pixels = 0

    for images, _ in tqdm(loader, desc="Calculating Stats"):
        # images shape is (batch, channels, height, width)
        batch_size, channels, height, width = images.shape
        num_pixels += batch_size * height * width

        images = images.to(torch.float64)
        sum_of_pixels += torch.sum(images, dim=[0, 2, 3])
        sum_of_squares += torch.sum(images ** 2, dim=[0, 2, 3])

    if num_pixels == 0:
        raise ValueError("Cannot compute statistics: the dataset produced no pixels.")

    mean = sum_of_pixels / num_pixels
    # Var[x] = E[x^2] - E[x]^2; clamp guards against tiny negative values
    # caused by rounding, which would otherwise turn into NaN under sqrt.
    variance = torch.clamp(sum_of_squares / num_pixels - mean ** 2, min=0.0)
    std = torch.sqrt(variance)

    print(f"Calculation Complete.\nMean: {mean}\nStd: {std}")
    return mean.tolist(), std.tolist()

In [None]:
import torch.nn as nn

class SimpleSegmentationModel(nn.Module):
    """Small convolutional encoder/decoder for binary segmentation.

    The encoder halves the spatial resolution twice (two 2x2 max-pools) and
    the upsampler doubles it twice (two stride-2 transposed convolutions), so
    the output has the same height and width as the input whenever both are
    divisible by 4.

    ``forward`` returns raw logits, NOT probabilities. The notebook trains
    this model with ``nn.BCEWithLogitsLoss`` and applies ``torch.sigmoid`` to
    its output at inference time; applying a sigmoid inside ``forward`` as
    well would squash the values twice and confine predictions to a narrow
    band around 0.5-0.73.
    """

    def __init__(self):
        super().__init__()

        # 3 -> 16 -> 32 channels, spatial size H x W -> H/4 x W/4.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        # 32 -> 16 -> 16 channels, spatial size H/4 x W/4 -> H x W.
        self.upsampler = nn.Sequential(
            nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2),
            nn.ReLU(),

            nn.ConvTranspose2d(16, 16, kernel_size=2, stride=2),
            nn.ReLU()
        )

        # 1x1 convolution mapping the 16 feature channels to one logit map.
        self.final_conv = nn.Conv2d(16, 1, kernel_size=1)

        # Kept as an attribute for callers that want probabilities
        # (``model.sigmoid(model(x))``); it is deliberately not used in
        # ``forward``. It has no parameters, so checkpoints are unaffected.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-pixel logits of shape ``(batch, 1, H, W)``."""
        encoded_features = self.encoder(x)
        upsampled_features = self.upsampler(encoded_features)
        logits = self.final_conv(upsampled_features)
        return logits

In [None]:

from torch.utils.data import DataLoader
import torchvision.transforms as T
import torch
import torch.nn as nn
from tqdm import tqdm
import torch.optim as optim

print("Preparing for Simple Model Training")

# Hyperparameters for the baseline model.
NUM_EPOCHS = 15
LEARNING_RATE = 0.001
BATCH_SIZE = 32
IMAGE_SIZE = (512, 512)

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Sliding-window datasets (default window size and stride, no transform).
train_data_simple = ChineseCharacterDataset(train_dataset, annotations_data)
val_data_simple = ChineseCharacterDataset(val_dataset, annotations_data)

# Only the training loader is shuffled.
train_loader_simple = DataLoader(train_data_simple, batch_size=BATCH_SIZE, shuffle=True)
val_loader_simple = DataLoader(val_data_simple, batch_size=BATCH_SIZE)

def calculate_pos_weight(dataloader):
    """Compute the negative/positive pixel ratio over all masks of a loader.

    The result is meant to be used as ``pos_weight`` of
    ``nn.BCEWithLogitsLoss`` to counter class imbalance.

    Args:
        dataloader: iterable yielding ``(images, masks)`` batches whose masks
            contain the values 0 (background) and 1 (foreground).

    Returns:
        A 0-dim float32 tensor holding ``num_negative / num_positive``.

    Raises:
        ValueError: if no positive pixel is found; the ratio would be
            infinite and would produce NaN losses downstream.
    """
    print("Calculating pos_weight for BCE loss")
    num_pos = 0
    num_neg = 0
    for _, masks in tqdm(dataloader, desc="Calculating pos_weight"):
        # Accumulate as Python ints so the counters cannot lose precision.
        num_pos += int(torch.sum(masks == 1))
        num_neg += int(torch.sum(masks == 0))

    if num_pos == 0:
        raise ValueError("No positive pixels found in the masks; cannot compute pos_weight.")

    return torch.tensor(num_neg / num_pos, dtype=torch.float32)

# Ratio of background to foreground pixels in the training windows.
pos_weight = calculate_pos_weight(train_loader_simple)
print(f"Positive weight calculated: {float(pos_weight):.2f}")

print("Initializing Model, Loss Function, and Optimizer")
model_1 = SimpleSegmentationModel().to(device)

# torch.as_tensor accepts both a Python float and an existing tensor without
# the copy-construct UserWarning that torch.tensor(<tensor>) emits.
loss_function = nn.BCEWithLogitsLoss(
    pos_weight=torch.as_tensor(pos_weight, dtype=torch.float32).to(device)
)
optimizer = optim.Adam(model_1.parameters(), lr=LEARNING_RATE)

In [None]:

# Average training loss of every epoch, in order.
train_loss_history = []
print(f"Starting training for {NUM_EPOCHS} epochs...")
for epoch in range(NUM_EPOCHS):
    model_1.train()
    running_train_loss = 0.0
    train_progress_bar = tqdm(train_loader_simple, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Train]")

    for images, masks in train_progress_bar:
        images, masks = images.to(device), masks.to(device)

        # Forward pass and loss for this batch.
        outputs = model_1(images)
        loss = loss_function(outputs, masks)

        # Standard optimisation step: clear old gradients, backprop, update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss_value = loss.item()
        running_train_loss += batch_loss_value
        train_progress_bar.set_postfix(loss=f"{batch_loss_value:.4f}")

    # Mean of the per-batch losses of this epoch.
    avg_train_loss = running_train_loss / len(train_loader_simple)
    train_loss_history.append(avg_train_loss)
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] - Average Training Loss: {avg_train_loss:.4f}")

print("\nTraining finished")

# Persist only the learned parameters (state dict), not the whole module.
MODEL_SAVE_PATH = "simple_segmentation_model_improved.pth"
torch.save(model_1.state_dict(), MODEL_SAVE_PATH)
print(f"\nModel 1 has been trained and saved to {MODEL_SAVE_PATH}")

In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Inspired by https://www.digitalocean.com/community/tutorials/writing-resnet-from-scratch-in-pytorch
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with a skip connection.

    When ``stride`` is not 1 or the channel count changes, the skip path uses
    a 1x1 convolution followed by batch-norm so that both branches have the
    same shape; otherwise the input is added unchanged.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # An empty Sequential acts as the identity function.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        residual = self.shortcut(x)
        features = self.conv1(x)
        features = F.relu(self.bn1(features))
        features = self.bn2(self.conv2(features))
        return F.relu(features + residual)



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class UNetResNetDeepSupervision(nn.Module):
    """U-Net style decoder on a torchvision ResNet-34 encoder.

    In training mode ``forward`` returns four logit maps: the final output
    plus three auxiliary outputs taken from intermediate decoder stages. In
    evaluation mode only the final logit map is returned.
    """

    def __init__(self, n_classes=1):
        super().__init__()

        # ImageNet-pretrained backbone; its child modules are regrouped below
        # into five encoder stages (they share parameters with base_model).
        self.base_model = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)

        backbone_children = list(self.base_model.children())
        self.encoder_layers = backbone_children
        self.layer0 = nn.Sequential(*backbone_children[:3])
        self.layer1 = nn.Sequential(*backbone_children[3:5])
        self.layer2 = backbone_children[5]
        self.layer3 = backbone_children[6]
        self.layer4 = backbone_children[7]

        # Decoder: each stage upsamples by 2, then fuses the concatenated
        # encoder features of matching resolution with two 3x3 convolutions.
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec_conv3 = self._make_decoder_block(256 + 256, 256)

        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec_conv2 = self._make_decoder_block(128 + 128, 128)

        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec_conv1 = self._make_decoder_block(64 + 64, 64)

        self.upconv0 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec_conv0 = self._make_decoder_block(32 + 64, 32)

        self.final_upconv = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)

        # 1x1 heads: one for the final output, three for deep supervision.
        self.out_final = nn.Conv2d(16, n_classes, kernel_size=1)
        self.out_1 = nn.Conv2d(64, n_classes, kernel_size=1)
        self.out_2 = nn.Conv2d(128, n_classes, kernel_size=1)
        self.out_3 = nn.Conv2d(256, n_classes, kernel_size=1)

    def _make_decoder_block(self, in_channels, out_channels):
        """Build a (conv3x3 -> batch-norm -> ReLU) x 2 block."""
        stages = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*stages)

    def forward(self, x):
        """Return logits; a 4-tuple in training mode, one tensor otherwise."""
        # Encoder path.
        enc0 = self.layer0(x)
        enc1 = self.layer1(enc0)
        enc2 = self.layer2(enc1)
        enc3 = self.layer3(enc2)
        enc4 = self.layer4(enc3)

        # Decoder path with skip connections (channel-wise concatenation).
        dec3 = self.dec_conv3(torch.cat([self.upconv3(enc4), enc3], dim=1))
        dec2 = self.dec_conv2(torch.cat([self.upconv2(dec3), enc2], dim=1))
        dec1 = self.dec_conv1(torch.cat([self.upconv1(dec2), enc1], dim=1))
        dec0 = self.dec_conv0(torch.cat([self.upconv0(dec1), enc0], dim=1))

        out_final = self.out_final(self.final_upconv(dec0))

        if not self.training:
            return out_final

        # Auxiliary heads are only evaluated while training.
        aux1 = self.out_1(dec1)
        aux2 = self.out_2(dec2)
        aux3 = self.out_3(dec3)
        return out_final, aux1, aux2, aux3

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation.

    Args:
        smooth: constant added to numerator and denominator; it keeps the
            ratio defined when both prediction and target are empty.
    """
    def __init__(self, smooth=1.):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        """Return ``1 - dice`` computed over all elements of the batch.

        Args:
            logits: raw model outputs; a sigmoid is applied here.
            targets: tensor with the same number of elements as ``logits``,
                holding the ground-truth mask.
        """
        probs = torch.sigmoid(logits)
        # reshape (unlike view) also accepts non-contiguous tensors, e.g.
        # transposed or sliced inputs, copying only when necessary.
        probs_flat = probs.reshape(-1)
        targets_flat = targets.reshape(-1)
        intersection = (probs_flat * targets_flat).sum()
        dice_coefficient = (2. * intersection + self.smooth) / (probs_flat.sum() + targets_flat.sum() + self.smooth)
        return 1 - dice_coefficient

class FocalLoss(nn.Module):
    """Binary focal loss on logits.

    Each element's binary cross-entropy is scaled by ``(1 - p_t) ** gamma``
    and by a class weight (``alpha`` where the target is 1, ``1 - alpha``
    where it is 0).

    Args:
        alpha: weight of the positive class.
        gamma: focusing exponent.
        reduction: ``'mean'``, ``'sum'``, or anything else for no reduction.
    """
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        """Return the focal loss, reduced according to ``self.reduction``."""
        per_element_bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        probs = torch.sigmoid(logits)

        # Probability the model assigns to the true class of each element.
        prob_of_truth = probs * targets + (1 - probs) * (1 - targets)
        modulating_factor = (1.0 - prob_of_truth) ** self.gamma

        class_weight = self.alpha * targets + (1 - self.alpha) * (1 - targets)

        weighted = class_weight * modulating_factor * per_element_bce

        if self.reduction == 'mean':
            return weighted.mean()
        if self.reduction == 'sum':
            return weighted.sum()
        return weighted

class FocalDiceLoss(nn.Module):
    """Sum of a focal loss term and a Dice loss term.

    Both terms receive the same logits and targets and are added with equal
    weight.

    Args:
        alpha: forwarded to ``FocalLoss``.
        gamma: forwarded to ``FocalLoss``.
        smooth: forwarded to ``DiceLoss``.
    """
    def __init__(self, alpha=0.25, gamma=2.0, smooth=1.):
        super().__init__()
        self.focal = FocalLoss(alpha=alpha, gamma=gamma)
        self.dice = DiceLoss(smooth=smooth)

    def forward(self, logits, targets):
        """Return ``focal(logits, targets) + dice(logits, targets)``."""
        focal_term = self.focal(logits, targets)
        dice_term = self.dice(logits, targets)
        return focal_term + dice_term

class DeepSupervisionLoss(nn.Module):
    """Weighted sum of a base loss over the outputs of a deeply supervised model.

    Args:
        base_loss_fn: callable ``(logits, targets) -> scalar loss``.
        weights: optional weights for the AUXILIARY outputs, in order. The
            main output always gets weight 1.0. Defaults to
            ``0.5, 0.25, 0.125, ...`` (ten values). The sequence is copied,
            so the caller's object is never modified.
    """
    def __init__(self, base_loss_fn, weights=None):
        super(DeepSupervisionLoss, self).__init__()
        self.base_loss_fn = base_loss_fn
        auxiliary_weights = list(weights) if weights is not None else [0.5 / (2**i) for i in range(10)]
        # Build a new list instead of inserting into the argument in place.
        self.weights = [1.0] + auxiliary_weights

    def forward(self, outputs, masks):
        """Return the total weighted loss.

        Args:
            outputs: sequence of logit tensors, main output first. A single
                tensor (e.g. what a model returns in eval mode) is treated as
                a one-element sequence.
            masks: full-resolution target of shape ``(batch, C, H, W)``; it is
                resized with nearest-neighbour interpolation to the spatial
                size of each auxiliary output.

        Raises:
            IndexError: if there are more outputs than available weights.
        """
        if torch.is_tensor(outputs):
            # Indexing a bare tensor would pick sample 0 of the batch rather
            # than "the main output".
            outputs = (outputs,)

        total_loss = self.weights[0] * self.base_loss_fn(outputs[0], masks)

        for i in range(1, len(outputs)):
            downsampled_mask = F.interpolate(masks, size=outputs[i].shape[2:], mode='nearest')

            auxiliary_loss = self.weights[i] * self.base_loss_fn(outputs[i], downsampled_mask)
            total_loss = total_loss + auxiliary_loss

        return total_loss

In [None]:
import torch
from tqdm import tqdm

def train_model_epoch_ds(model, dataloader, optimizer, loss_function, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: module to train; it is switched to training mode.
        dataloader: iterable of ``(images, masks)`` batches.
        optimizer: optimizer that updates the model parameters.
        loss_function: callable ``(model_output, masks) -> scalar loss``.
        device: device the batches are moved to.
    """
    model.train()
    loss_sum = 0.0
    batches = tqdm(dataloader, desc="Training Epoch")
    for batch_images, batch_masks in batches:
        batch_images = batch_images.to(device)
        batch_masks = batch_masks.to(device)

        predictions = model(batch_images)
        batch_loss = loss_function(predictions, batch_masks)

        # Clear stale gradients, backpropagate, then update the weights.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        loss_value = batch_loss.item()
        loss_sum += loss_value
        batches.set_postfix(loss=f"{loss_value:.4f}")
    return loss_sum / len(dataloader)

def validate_model_epoch_ds(model, dataloader, loss_function, device):
    """Evaluate the model on every batch and return the mean per-batch loss.

    The model is put into evaluation mode and gradients are disabled for the
    whole pass.

    Args:
        model: module to evaluate.
        dataloader: iterable of ``(images, masks)`` batches.
        loss_function: callable ``(model_output, masks) -> scalar loss``.
        device: device the batches are moved to.
    """
    model.eval()
    loss_sum = 0.0

    with torch.no_grad():
        for batch_images, batch_masks in dataloader:
            batch_images = batch_images.to(device)
            batch_masks = batch_masks.to(device)
            batch_loss = loss_function(model(batch_images), batch_masks)
            loss_sum += batch_loss.item()

    return loss_sum / len(dataloader)

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as T
import gc

# Hyperparameters
FREEZE_EPOCHS = 1
UNFREEZE_EPOCHS = 2
TOTAL_EPOCHS = FREEZE_EPOCHS + UNFREEZE_EPOCHS
LEARNING_RATE_DECODER = 0.001
LEARNING_RATE_FINETUNE = 0.00005
IMAGE_SIZE = (768, 768)
BATCH_SIZE = 16
EARLY_STOPPING_PATIENCE = 5

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"--- Training Deep Supervision model on: {device} ---")

# Official ImageNet statistics (the blue-channel mean is 0.406, not 0.0406).
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# ChineseCharacterDataset applies `transform` to the image tensor only, never
# to the mask. Random flips/rotations here would therefore move the image
# while the mask stays in place and corrupt the labels, so only photometric
# augmentation and normalisation are used.
train_transform = T.Compose([
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    T.Normalize(mean=imagenet_mean, std=imagenet_std)
])
val_test_transform = T.Compose([T.Normalize(mean=imagenet_mean, std=imagenet_std)])
train_data = ChineseCharacterDataset(train_dataset, annotations_data, transform=train_transform)
val_data = ChineseCharacterDataset(val_dataset, annotations_data, transform=val_test_transform)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE)

final_model = UNetResNetDeepSupervision().to(device)
base_loss = FocalDiceLoss().to(device)
# Training uses the deep-supervision loss (all four outputs); validation uses
# base_loss because the model returns only the final output in eval mode.
loss_fn = DeepSupervisionLoss(base_loss_fn=base_loss).to(device)

try:
    print(f"\n--- STAGE 1: Freezing encoder and training decoder for {FREEZE_EPOCHS} epochs ---")
    for param in final_model.base_model.parameters():
        param.requires_grad = False

    optimizer = optim.AdamW(final_model.parameters(), lr=LEARNING_RATE_DECODER)
    # `verbose=` is deprecated in recent PyTorch releases, so it is not passed.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=3)

    # float('inf') instead of np.Inf: that alias was removed in NumPy 2.0.
    best_val_loss = float('inf')
    epochs_no_improve = 0
    MODEL_SAVE_PATH = "deep_supervision_model_best.pth"

    for epoch in range(FREEZE_EPOCHS):
        avg_train_loss = train_model_epoch_ds(final_model, train_loader, optimizer, loss_fn, device)
        avg_val_loss = validate_model_epoch_ds(final_model, val_loader, base_loss, device) # Note: val uses base_loss
        scheduler.step(avg_val_loss)
        print(f"Epoch [{epoch+1}/{TOTAL_EPOCHS}] -> Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(final_model.state_dict(), MODEL_SAVE_PATH)
            print(f"  -> Checkpoint saved! New best validation loss: {best_val_loss:.4f}")
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            print(f"  -> Validation loss did not improve for {epochs_no_improve} epoch(s).")

        if epochs_no_improve >= EARLY_STOPPING_PATIENCE:
            print(f"\n--- Early stopping triggered after {EARLY_STOPPING_PATIENCE} epochs with no improvement. ---")
            break

    print(f"\n--- STAGE 2: Unfreezing encoder and fine-tuning all layers for {UNFREEZE_EPOCHS} epochs ---")
    for param in final_model.base_model.parameters():
        param.requires_grad = True

    # Resume from the best stage-1 weights, but only if this run actually
    # wrote a checkpoint; otherwise a stale file from an earlier run could be
    # loaded (or the load would fail).
    if best_val_loss < float('inf'):
        final_model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device))
    optimizer = optim.AdamW(final_model.parameters(), lr=LEARNING_RATE_FINETUNE)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=3)
    epochs_no_improve = 0

    for epoch in range(UNFREEZE_EPOCHS):
        if epochs_no_improve >= EARLY_STOPPING_PATIENCE:
            break

        avg_train_loss = train_model_epoch_ds(final_model, train_loader, optimizer, loss_fn, device)
        avg_val_loss = validate_model_epoch_ds(final_model, val_loader, base_loss, device)
        scheduler.step(avg_val_loss)
        print(f"Epoch [{epoch+1+FREEZE_EPOCHS}/{TOTAL_EPOCHS}] -> Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(final_model.state_dict(), MODEL_SAVE_PATH)
            print(f"  -> Checkpoint saved! New best validation loss: {best_val_loss:.4f}")
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            print(f"  -> Validation loss did not improve for {epochs_no_improve} epoch(s).")

        if epochs_no_improve >= EARLY_STOPPING_PATIENCE:
            print(f"\n--- Early stopping triggered after {EARLY_STOPPING_PATIENCE} epochs with no improvement. ---")
            break

    print("\n--- Training Finished ---")
    print(f"The best model was saved to {MODEL_SAVE_PATH} with a final validation loss of {best_val_loss:.4f}")

finally:
    # Cleanup block: runs on success, on error and on manual interruption.
    # The in-memory model is released; reload it from MODEL_SAVE_PATH for
    # evaluation.
    print("\n--- Finalizing session and cleaning up CUDA memory... ---")
    try:
        del final_model, train_loader, val_loader, train_data, val_data
    except NameError:
        print("Some objects were not defined, skipping deletion.")
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("--- Cleanup complete. ---")

In [None]:
def find_optimal_threshold(model, dataloader, device):
    """Find the probability threshold that maximises the pixel-wise F1-score.

    Thresholds ``0.10, 0.15, ..., 0.85`` are evaluated. True positives, false
    positives and false negatives are accumulated batch by batch as integer
    counts, so memory use does not grow with the size of the dataset and the
    counts stay exact.

    Args:
        model: segmentation model returning logits; put into eval mode.
        dataloader: iterable of ``(images, masks)`` batches with 0/1 masks.
        device: device the batches are moved to.

    Returns:
        The best threshold, or 0 if no threshold reaches an F1 above 0.
    """
    model.eval()

    thresholds = np.arange(0.1, 0.9, 0.05)
    tp_counts = [0] * len(thresholds)
    fp_counts = [0] * len(thresholds)
    fn_counts = [0] * len(thresholds)

    with torch.no_grad():
        for images, masks in tqdm(dataloader, desc="Gathering predictions for threshold tuning"):
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            preds = torch.sigmoid(outputs).reshape(-1)
            # Masks hold 0/1 values; convert to booleans for exact counting.
            positives = masks.reshape(-1) > 0.5
            negatives = ~positives

            for k, threshold in enumerate(thresholds):
                predicted = preds > threshold
                tp_counts[k] += int((predicted & positives).sum())
                fp_counts[k] += int((predicted & negatives).sum())
                fn_counts[k] += int((~predicted & positives).sum())

    best_f1 = 0
    best_threshold = 0

    for threshold, tp, fp, fn in zip(thresholds, tp_counts, fp_counts, fn_counts):
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold

    print(f"Optimal threshold found: {best_threshold:.2f} (with F1-score: {best_f1:.4f})")
    return best_threshold

In [None]:
import numpy as np
import random
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as T
import cv2
import math


def visualize_sw_comparison(simple_model, sw_model, dataset_list, device, num_samples=3, simple_model_size=(512, 512), window_size=256, stride=128, threshold=0.5, sw_transform=None):
    """Plot original image, ground truth and both models' predictions.

    For each randomly drawn sample one row with four panels is drawn: the
    original image, the ground-truth mask, the simple model's prediction
    (whole image resized to ``simple_model_size``) and the sliding-window
    model's prediction (overlapping windows, probabilities averaged).
    Ground truth is read from the notebook-level ``annotations_data``.

    Args:
        simple_model: model applied to the resized full image; its output is
            passed through a sigmoid here, i.e. it is treated as logits.
        sw_model: model applied to each window; output treated as logits.
        dataset_list: list of dicts with an ``'image_path'`` entry.
        device: device used for inference.
        num_samples: number of rows; clamped to ``len(dataset_list)``.
        simple_model_size: size passed to ``cv2.resize`` for the simple model.
        window_size: side length of the sliding window in pixels.
        stride: step of the sliding window in pixels.
        threshold: probability above which a pixel is shown as foreground.
        sw_transform: optional callable applied to every window tensor before
            it is fed to ``sw_model``. Pass the same normalisation the model
            was trained with; with ``None`` the windows are fed unnormalised.

    Raises:
        OSError: if an image cannot be loaded.
    """
    simple_model.eval()
    sw_model.eval()

    # random.sample raises when asked for more items than are available.
    num_samples = min(num_samples, len(dataset_list))
    if num_samples <= 0:
        print("No samples available to visualize.")
        return

    # squeeze=False keeps `ax` two-dimensional even for a single row, so the
    # ax[i, j] indexing below also works with num_samples == 1.
    fig, ax = plt.subplots(num_samples, 4, figsize=(20, 5 * num_samples), squeeze=False)
    fig.suptitle('Simple Model vs. Sliding Window Model Comparison', fontsize=20, y=1.02)

    sample_items = random.sample(dataset_list, num_samples)

    for i, item in enumerate(sample_items):
        image_path = item['image_path']
        image_id = image_path.split('/')[-1].replace('.jpg', '')

        original_image = cv2.imread(image_path)
        if original_image is None:
            # cv2.imread returns None instead of raising on failure.
            raise OSError(f"cv2.imread failed to load image: {image_path}")
        original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
        h, w, _ = original_image.shape

        # Ground truth: fill every polygon flagged as Chinese with 1.
        gt_mask_full = np.zeros((h, w), dtype=np.uint8)
        if image_id in annotations_data:
            for char_list in annotations_data[image_id]['annotations']:
                if isinstance(char_list, list) and len(char_list) > 0:
                    for char_annotation in char_list:
                        if char_annotation.get('is_chinese', False):
                            polygon = np.array(char_annotation['polygon'], dtype=np.int32).reshape(-1, 2)
                            cv2.fillPoly(gt_mask_full, [polygon], 1)

        # Simple model: predict on the resized image, then scale the
        # probability map back to the original size.
        resized_for_simple = cv2.resize(original_image, simple_model_size)
        simple_tensor = T.ToTensor()(resized_for_simple).unsqueeze(0).to(device)
        with torch.no_grad():
            pred_simple_raw = simple_model(simple_tensor)
            pred_simple_prob = torch.sigmoid(pred_simple_raw).squeeze().cpu().numpy()
        pred_simple_full = cv2.resize(pred_simple_prob, (w, h))

        # Sliding window: sum probabilities and count how often each pixel
        # was covered, then average.
        final_sw_mask = np.zeros((h, w), dtype=np.float32)
        counts_mask = np.zeros((h, w), dtype=np.float32)

        for y in range(0, h, stride):
            for x in range(0, w, stride):
                x_end, y_end = min(x + window_size, w), min(y + window_size, h)
                image_patch = original_image[y:y_end, x:x_end]

                pad_x = window_size - image_patch.shape[1]
                pad_y = window_size - image_patch.shape[0]
                if pad_x > 0 or pad_y > 0:
                    image_patch = np.pad(image_patch, ((0, pad_y), (0, pad_x), (0, 0)), mode='constant')

                patch_tensor = T.ToTensor()(image_patch)
                if sw_transform is not None:
                    patch_tensor = sw_transform(patch_tensor)
                patch_tensor = patch_tensor.unsqueeze(0).to(device)

                with torch.no_grad():
                    pred_patch_raw = sw_model(patch_tensor)
                    pred_patch_prob = torch.sigmoid(pred_patch_raw).squeeze().cpu().numpy()

                # Drop the padded region before accumulating.
                pred_patch_prob = pred_patch_prob[:y_end-y, :x_end-x]
                final_sw_mask[y:y_end, x:x_end] += pred_patch_prob
                counts_mask[y:y_end, x:x_end] += 1

        final_sw_mask /= np.maximum(counts_mask, 1)

        ax[i, 0].imshow(original_image); ax[i, 0].set_title(f"Sample {i+1}: Original"); ax[i, 0].axis('off')
        ax[i, 1].imshow(gt_mask_full, cmap='gray'); ax[i, 1].set_title("Ground Truth"); ax[i, 1].axis('off')
        ax[i, 2].imshow(pred_simple_full > threshold, cmap='gray'); ax[i, 2].set_title("Simple Model Pred."); ax[i, 2].axis('off')
        ax[i, 3].imshow(final_sw_mask > threshold, cmap='gray'); ax[i, 3].set_title("Sliding Window Pred."); ax[i, 3].axis('off')

    plt.tight_layout()
    plt.show()

def calculate_metrics_simple(simple_model, dataset_list, device, resize_shape=(512, 512), threshold=0.5):
    """Compute pixel-wise accuracy, precision, recall and F1 for the simple model.

    Every image is resized to ``resize_shape``, passed through the model, and
    the resulting probability map is resized back to the original resolution
    before it is thresholded and compared with the ground-truth mask. The
    confusion counts are accumulated over all images (micro-averaging).
    Ground truth is read from the notebook-level ``annotations_data``.

    Args:
        simple_model: model to evaluate; it is put into eval mode.
        dataset_list: list of dicts with an ``'image_path'`` entry.
        device: device used for inference.
        resize_shape: size passed to ``cv2.resize`` for the model input.
        threshold: probability above which a pixel counts as foreground.

    Returns:
        ``(accuracy, precision, recall, f1_score)``; each value is 0 when its
        denominator is 0.
    """
    simple_model.eval()
    total_tp, total_fp, total_fn, total_tn = 0, 0, 0, 0

    for item in tqdm(dataset_list, desc="Metrics for Simple Model"):
        # NOTE(review): cv2.imread returns None for unreadable files, which
        # would fail on the `.shape` access below -- confirm inputs are valid.
        original_image = cv2.imread(item['image_path'])
        h, w, _ = original_image.shape
        # Ground-truth mask: 1 inside polygons flagged 'is_chinese', else 0.
        gt_mask_full = np.zeros((h, w), dtype=np.uint8)
        image_id = item['image_path'].split('/')[-1].replace('.jpg', '')
        if image_id in annotations_data:
            for char_list in annotations_data[image_id]['annotations']:
                if isinstance(char_list, list) and len(char_list) > 0:
                    for char_annotation in char_list:
                        if char_annotation.get('is_chinese', False):
                            polygon = np.array(char_annotation['polygon'], dtype=np.int32).reshape(-1, 2)
                            cv2.fillPoly(gt_mask_full, [polygon], 1)
        
        # The image is read as BGR, so convert to RGB before resizing.
        resized_image = cv2.resize(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB), resize_shape)
        tensor_image = T.ToTensor()(resized_image).unsqueeze(0).to(device)
        with torch.no_grad():
            output = simple_model(tensor_image)
            # NOTE(review): the sigmoid here assumes the model returns raw
            # logits -- confirm against the model's forward().
            probs = torch.sigmoid(output).squeeze().cpu().numpy()
        
        # Scale the probability map back to the original (w, h) size.
        pred_full_size = cv2.resize(probs, (w, h))
        pred_binary = (pred_full_size > threshold).astype(np.uint8)

        # Both arrays hold only 0/1 as uint8, so the products and `1 - x`
        # act as logical AND / NOT without wrap-around.
        total_tp += (pred_binary * gt_mask_full).sum()
        total_fp += (pred_binary * (1 - gt_mask_full)).sum()
        total_fn += ((1 - pred_binary) * gt_mask_full).sum()
        total_tn += ((1 - pred_binary) * (1 - gt_mask_full)).sum()
        
    # Guard every ratio against a zero denominator.
    accuracy = (total_tp + total_tn) / (total_tp + total_tn + total_fp + total_fn) if (total_tp + total_tn + total_fp + total_fn) > 0 else 0
    precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
    recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    
    return accuracy, precision, recall, f1_score

def _build_chinese_gt_mask(image_id, height, width, annotations_lookup):
    """Rasterise the Chinese-character polygons of one image into a 0/1 mask.

    Args:
        image_id: Key used to look the image up in ``annotations_lookup``.
        height, width: Size of the mask to create (full image resolution).
        annotations_lookup: Mapping ``image_id -> record`` where
            ``record['annotations']`` is a list of lists of character dicts
            carrying ``'is_chinese'`` and ``'polygon'`` entries.

    Returns:
        ``(height, width)`` uint8 array, 1 inside Chinese-character polygons.
        All zeros when the image has no entry in ``annotations_lookup``.
    """
    gt_mask = np.zeros((height, width), dtype=np.uint8)
    if image_id not in annotations_lookup:
        return gt_mask

    for char_list in annotations_lookup[image_id]['annotations']:
        # Entries that are not lists are ignored, as are empty lists.
        if not isinstance(char_list, list):
            continue
        for char_annotation in char_list:
            if char_annotation.get('is_chinese', False):
                polygon = np.array(char_annotation['polygon'], dtype=np.int32).reshape(-1, 2)
                cv2.fillPoly(gt_mask, [polygon], 1)
    return gt_mask


def _predict_sliding_window(sw_model, rgb_image, device, window_size, stride):
    """Return the per-pixel probability map averaged over overlapping windows.

    Windows that run past the right/bottom border are zero-padded up to
    ``window_size`` before inference, and the padded part of the prediction is
    cropped away again before it is accumulated.
    """
    height, width = rgb_image.shape[:2]
    prob_sum = np.zeros((height, width), dtype=np.float32)
    hit_count = np.zeros((height, width), dtype=np.float32)
    to_tensor = T.ToTensor()  # loop-invariant, build once

    with torch.no_grad():
        for y in range(0, height, stride):
            y_end = min(y + window_size, height)
            for x in range(0, width, stride):
                x_end = min(x + window_size, width)
                image_patch = rgb_image[y:y_end, x:x_end]

                pad_y = window_size - image_patch.shape[0]
                pad_x = window_size - image_patch.shape[1]
                if pad_x > 0 or pad_y > 0:
                    image_patch = np.pad(
                        image_patch, ((0, pad_y), (0, pad_x), (0, 0)), mode='constant'
                    )

                patch_tensor = to_tensor(image_patch).unsqueeze(0).to(device)
                patch_prob = torch.sigmoid(sw_model(patch_tensor)).squeeze().cpu().numpy()

                # Drop the rows/columns that only exist because of padding.
                prob_sum[y:y_end, x:x_end] += patch_prob[:y_end - y, :x_end - x]
                hit_count[y:y_end, x:x_end] += 1

    # Pixels never covered (only possible when stride > window_size) keep probability 0.
    return prob_sum / np.maximum(hit_count, 1)


def calculate_metrics_sw(sw_model, dataset_list, device, window_size=256, stride=128, threshold=0.5,
                         annotations_lookup=None):
    """Compute pixel-level metrics for a model applied with a sliding window.

    For every item the image is read at full resolution, a ground-truth mask of
    Chinese-character polygons is built, the model is run on overlapping
    ``window_size`` x ``window_size`` patches, and the averaged probabilities
    are thresholded. True/false positives/negatives are summed over the whole
    dataset (micro-averaged) before the metrics are derived.

    Args:
        sw_model: Model called on a ``(1, C, window_size, window_size)`` tensor;
            its output is passed through ``torch.sigmoid``.
        dataset_list: Iterable of dicts with an ``'image_path'`` entry.
        device: Device the patch tensors are moved to.
        window_size: Side length of the square inference window, in pixels.
        stride: Step between consecutive windows, in pixels.
        threshold: A pixel is predicted positive when its averaged probability
            is strictly greater than this value.
        annotations_lookup: Optional mapping ``image_id -> annotation record``.
            Defaults to the notebook-level ``annotations_data``.

    Returns:
        Tuple ``(accuracy, precision, recall, f1_score)``; each metric is 0
        when its denominator is 0.

    Raises:
        ValueError: If ``window_size`` or ``stride`` is not positive.
        OSError: If an image cannot be read by ``cv2.imread``.
    """
    if window_size <= 0 or stride <= 0:
        raise ValueError(
            f"window_size and stride must be positive, got window_size={window_size}, stride={stride}"
        )
    if annotations_lookup is None:
        annotations_lookup = annotations_data

    sw_model.eval()
    total_tp = total_fp = total_fn = total_tn = 0

    for item in tqdm(dataset_list, desc="Metrics for Sliding Window Model"):
        image_path = item['image_path']
        original_image = cv2.imread(image_path)
        if original_image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"cv2.imread could not read image: {image_path}")

        h, w = original_image.shape[:2]
        image_id = os.path.splitext(os.path.basename(image_path))[0]
        gt_positive = _build_chinese_gt_mask(image_id, h, w, annotations_lookup).astype(bool)

        rgb_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
        averaged_prob = _predict_sliding_window(sw_model, rgb_image, device, window_size, stride)
        pred_positive = averaged_prob > threshold

        # Count with Python ints so the running totals never inherit a fixed-width NumPy dtype.
        tp = int(np.count_nonzero(pred_positive & gt_positive))
        fp = int(np.count_nonzero(pred_positive & ~gt_positive))
        fn = int(np.count_nonzero(~pred_positive & gt_positive))
        total_tp += tp
        total_fp += fp
        total_fn += fn
        total_tn += h * w - tp - fp - fn

    total_pixels = total_tp + total_tn + total_fp + total_fn
    accuracy = (total_tp + total_tn) / total_pixels if total_pixels > 0 else 0
    precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
    recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    return accuracy, precision, recall, f1_score

# Final evaluation: load both trained models, compare them visually, then report
# pixel-level metrics for each on the full test set.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"--- Running Final Evaluation on device: {device} ---")

SIMPLE_MODEL_SIZE = (512, 512)  # input size the simple model is evaluated at
SW_WINDOW_SIZE = 256            # sliding-window patch side length
SW_STRIDE = 128                 # step between consecutive windows


def _load_trained_model(model_cls, checkpoint_path, device):
    """Instantiate ``model_cls`` on ``device`` and load its weights from ``checkpoint_path``.

    The model is returned only when the weights loaded successfully; any failure
    propagates, so the caller never ends up holding a randomly initialised model.
    """
    model = model_cls().to(device)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return model


# Stay None unless construction AND weight loading both succeed.
loaded_simple_model = None
loaded_final_model = None

try:
    print("Loading SimpleSegmentationModel...")
    simple_model_path = "simple_segmentation_model_improved.pth"
    loaded_simple_model = _load_trained_model(SimpleSegmentationModel, simple_model_path, device)
    print(f"-> Successfully loaded model from: {simple_model_path}")
except Exception as e:
    # Best-effort: report the problem and let the guard below skip the evaluation.
    print(f"Could not load simple model: {e}")

try:
    print("\nLoading Sliding Window U-Net Model...")
    final_model_path = "deep_supervision_model_best.pth"
    loaded_final_model = _load_trained_model(UNetResNetDeepSupervision, final_model_path, device)
    print(f"-> Successfully loaded model from: {final_model_path}")
except Exception as e:
    print(f"Could not load final model: {e}")

# Explicit None checks: the comparison is only meaningful with trained weights for both models.
if loaded_simple_model is not None and loaded_final_model is not None:
    print("\n" + "="*50)
    print("Visual Comparison of Model Predictions")
    print("="*50)
    
    visualize_sw_comparison(
        loaded_simple_model, 
        loaded_final_model, 
        test_dataset, 
        device=device,
        simple_model_size=SIMPLE_MODEL_SIZE,
        window_size=SW_WINDOW_SIZE,
        stride=SW_STRIDE
    )

    print("\n" + "="*50)
    print("Quantitative Metrics Comparison on the Full Test Set")
    print("="*50)

    acc1, pre1, rec1, f1_1 = calculate_metrics_simple(
        loaded_simple_model, 
        test_dataset, 
        device,
        resize_shape=SIMPLE_MODEL_SIZE
    )
    print(f"\nResults for SimpleSegmentationModel:")
    print(f"  Accuracy:  {acc1:.4f}")
    print(f"  Precision: {pre1:.4f}")
    print(f"  Recall:    {rec1:.4f}")
    print(f"  F1-Score:  {f1_1:.4f}")

    acc2, pre2, rec2, f1_2 = calculate_metrics_sw(
        loaded_final_model, 
        test_dataset, 
        device, 
        window_size=SW_WINDOW_SIZE,
        stride=SW_STRIDE
    )
    print(f"\nResults for ResNet model with Sliding Window technique")
    print(f"  Accuracy:  {acc2:.4f}")
    print(f"  Precision: {pre2:.4f}")
    print(f"  Recall:    {rec2:.4f}")
    print(f"  F1-Score:  {f1_2:.4f}")

else:
    print("\nEvaluation skipped because one or both models could not be loaded.")