In [1]:
import torch
from torchvision import transforms
import yaml
from Pneumonia_predictor import PneumoniaDataset, PneumoniaPredictorCNN
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from torch.utils.tensorboard.writer import SummaryWriter
from datetime import datetime
import logging
from helper_funcs import init_writer
from Trainer import Trainer
from pathlib import Path
from torchsummary import summary


In [2]:
# Run-wide setup: logging, hyperparameter config, RNG seeds and output directory.
logging.basicConfig(level=logging.INFO)

# Config saved by an earlier run (same `<model_identifier>_config.yaml` naming the
# training cell writes), reused so this experiment shares its hyperparameters.
CONFIG_PATH = Path("output/model/6_smaller_img_bs32_lr0.0001_epoch15_img_size224x224_config.yaml")

# Explicit encoding: open()'s default is locale-dependent (e.g. cp1252 on Windows).
with open(CONFIG_PATH, "r", encoding="utf-8") as file:
    config = yaml.safe_load(file)

# Seed torch (DataLoader shuffling, torchvision random augmentations and, presumably,
# model weight init) and the global numpy RNG (used for the label shuffle).
seed = config["seed"]
torch.manual_seed(seed)
np.random.seed(seed)

model_output_dir = Path(config["model_output_dir"])
model_output_dir.mkdir(parents=True, exist_ok=True)


In [3]:
# Load the pre-split CSVs for each split from the configured data directory
# (read in train, val, test order).
data_dir = config["data_dir"]
train_data_df, val_data_df, test_data_df = (
    pd.read_csv(f"{data_dir}/{split}_data.csv") for split in ("train", "val", "test")
)
print(
    f"Train data shape: {train_data_df.shape} \n"
    f"Validation data shape: {val_data_df.shape} \n"
    f"Test data shape: {test_data_df.shape}"
)


Train data shape: (4185, 3) 
Validation data shape: (1047, 3) 
Test data shape: (624, 3)


In [4]:
# Random-label experiment: permute the encoded labels across training images.
# Class balance is preserved but the image -> label relationship is destroyed,
# so the model can only fit the training labels by memorisation.
#
# NOTE: this cell mutates train_data_df, so re-running it re-shuffles (and, for a
# full-trainset run, re-appends the validation rows). Re-run the data-loading
# cell before re-running this one.

# Only fold the validation rows into training for a full-trainset run. Appending
# them unconditionally would train on the very images val_loader evaluates.
if config.get("full_trainset", False):
    train_data_df = pd.concat([train_data_df, val_data_df], ignore_index=True)

# {encoded_value: text_label}, e.g. {1: 'PNEUMONIA', 0: 'NORMAL'}
label_map_df = train_data_df[['encoded_label', 'label']].drop_duplicates()
encoding_to_text_map = pd.Series(label_map_df.label.values, index=label_map_df.encoded_label).to_dict()

label_counts_before = train_data_df['encoded_label'].value_counts().sort_index()

# Shuffle a copy of the encoded labels with the global numpy RNG (seeded via
# np.random.seed), write it back, then re-derive the text label so both columns agree.
original_encoded_labels = train_data_df['encoded_label'].to_numpy(copy=True)
np.random.shuffle(original_encoded_labels)
train_data_df['encoded_label'] = original_encoded_labels
train_data_df['label'] = train_data_df['encoded_label'].map(encoding_to_text_map)

# A permutation cannot change class counts, and every code must map to a text label.
assert train_data_df['encoded_label'].value_counts().sort_index().equals(label_counts_before)
assert train_data_df['label'].notna().all()

print("\nRandomized label counts (should be same as original):")
print(train_data_df['encoded_label'].value_counts())


Randomized label counts (should be same as original):
encoded_label
1    3883
0    1349
Name: count, dtype: int64


In [5]:
image_size = config['image_size']


def _grayscale_resize():
    """Return fresh single-channel grayscale + resize steps, applied first for every split."""
    return [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize(image_size),
    ]


def _to_normalized_tensor():
    """Return fresh tensor-conversion steps, applied last for every split.

    For 8-bit PIL input, ToTensor yields values in [0, 1]; Normalize(0.5, 0.5)
    then maps them to [-1, 1].
    """
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]


# Training: shared preprocessing, then light random geometric augmentation.
train_transforms = transforms.Compose(
    _grayscale_resize()
    + [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=5),
    ]
    + _to_normalized_tensor()
)

# Validation / test: deterministic preprocessing only.
val_transforms = transforms.Compose(_grayscale_resize() + _to_normalized_tensor())

In [6]:
# DataLoaders. Only the training loader shuffles and uses augmentation. The
# evaluation batch sizes can be set via `val_batch_size` / `test_batch_size` in
# the config and default to 16 / 64.
num_workers = config['num_workers']

train_dataset = PneumoniaDataset(train_data_df, transform=train_transforms)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=num_workers
)

val_dataset = PneumoniaDataset(val_data_df, transform=val_transforms)
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=config.get('val_batch_size', 16), shuffle=False, num_workers=num_workers
)

test_dataset = PneumoniaDataset(test_data_df, transform=val_transforms)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=config.get('test_batch_size', 64), shuffle=False, num_workers=num_workers
)

In [7]:
# Prefer CUDA, then Apple MPS, then CPU. torch.backends.mps.is_available() is the
# documented MPS availability check.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Model hyperparameters come from the config; note the config keys
# conv_layers / fc_layers are passed as conv_defs / fc_defs.
model_params = {
    'image_size': config['image_size'],
    'in_channels': config['in_channels'],
    'conv_defs': config['conv_layers'],
    'fc_defs': config['fc_layers'],
    'fc_dropout': config['fc_dropout'],
    'fc_batch_norm': config['fc_batch_norm'],
}

model = PneumoniaPredictorCNN(**model_params)
model.to(device)

# BCEWithLogitsLoss expects raw logits (presumably a single output unit for the
# binary NORMAL/PNEUMONIA task -- TODO confirm against PneumoniaPredictorCNN).
optimizer = torch.optim.AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
loss_fn = torch.nn.BCEWithLogitsLoss()

# Reduce the LR when the monitored quantity stops decreasing (mode='min'; presumably
# validation loss stepped by the Trainer -- TODO confirm). The LR floor can be set
# via `min_lr` in the config and defaults to 1e-7.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=config['factor'],
    patience=config['patience'],
    cooldown=config['cooldown'],
    min_lr=config.get('min_lr', 1e-7),
)



In [None]:

# Unique per-run name (timestamped) plus a stable model name used for artefact files.
timestamp = datetime.now().strftime('%d_%m_%H%M%S')
model_identifier = (
    "6_random_labels"
    f"_bs{config['batch_size']}"
    f"_lr{config['learning_rate']}"
    f"_epoch{config['epochs']}"
    f"_img_size{config['image_size'][0]}x{config['image_size'][1]}"
)

run_identifier = f"{timestamp}_{model_identifier}"
config["run_identifier"] = run_identifier  # recorded in the config file saved below

final_model_path = model_output_dir / f"{model_identifier}_final.pth"
best_model_path = model_output_dir / f"{model_identifier}_best.pth"

writer = init_writer(config)
logging.info("Starting training run: %s", model_identifier)
trainer = Trainer(model, optimizer, loss_fn, scheduler, config, device, writer, logging)
try:
    trainer.train(train_loader, val_loader, config["epochs"])
finally:
    # Flush and release the TensorBoard event file even if training raises or is
    # interrupted. Assumes init_writer returns a SummaryWriter-like object with
    # .close() (SummaryWriter.close() is safe to call twice) -- TODO confirm.
    if writer is not None:
        writer.close()

# Save final weights.
torch.save(model.state_dict(), final_model_path)
logging.info("Final model saved to %s", final_model_path)

# The Trainer saves its best checkpoint under a fixed name; give it the
# run-specific name when validation was done, otherwise remove it.
# Path.replace (unlike Path.rename) also overwrites an existing target on Windows.
internal_best_path = model_output_dir / "best_model.pth"
if internal_best_path.exists():
    if config.get("full_trainset", False):
        # No validation was done, so the intermediate checkpoint is not kept.
        internal_best_path.unlink()
    else:
        internal_best_path.replace(best_model_path)
        logging.info("Best validation model renamed to %s", best_model_path)

# Save the config used for this run (including run_identifier) next to the weights.
config_path = model_output_dir / f"{model_identifier}_config.yaml"
with open(config_path, "w", encoding="utf-8") as f:
    yaml.dump(config, f)
logging.info("Configuration saved to %s", config_path)

