In [None]:
# Create a virtual environment named "myenv" in the current working directory.
# NOTE(review): this only creates the environment; nothing visible in this notebook
# activates it or points the running kernel at it, so the packages imported below
# presumably come from the kernel's own environment — confirm this cell is still needed.
!python3 -m venv myenv

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms

In [None]:
# Locations of the three dataset splits, all rooted at the shared data directory
dataset_path = '../../data'
train_path, val_path, test_path = (
    os.path.join(dataset_path, split_name)
    for split_name in ("train", "validation", "test")
)

# Run on the GPU when torch can see one, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG so repeated runs draw the same random numbers
seed = 42
torch.manual_seed(seed)

In [None]:
# Apply FFT to each image in the data pipeline
def apply_fft(image):
    """Return the centred log-magnitude 2-D spectrum of an image.

    Args:
        image: Either something ``transforms.ToTensor`` accepts (the pipeline
            passes the resized image straight in) or an already converted
            ``torch.Tensor`` whose last two dimensions are the spatial ones.

    Returns:
        torch.Tensor: A real tensor with the same shape as the converted
        image holding ``log(|FFT2| + 1e-5)``, with the zero frequency moved
        to the centre of the last two dimensions. It does not require
        gradients.
    """
    if isinstance(image, torch.Tensor):
        image_tensor = image
    else:
        image_tensor = transforms.ToTensor()(image)

    # fft2 transforms the last two dims only; every leading dim (channels) is independent.
    spectrum = torch.fft.fft2(image_tensor)

    # Shift ONLY the spatial axes. With no `dim`, fftshift rolls every axis, which
    # also rotates the channel axis and misaligns the per-channel Normalize that follows.
    spectrum = torch.fft.fftshift(spectrum, dim=(-2, -1))

    # Log-magnitude; the small constant avoids log(0).
    # The result is input data, so it is deliberately left without requires_grad:
    # flagging it would build an autograd graph for every sample and force
    # backpropagation all the way to the inputs, and it would hide the error
    # autograd raises when no model parameter is trainable.
    return torch.log(torch.abs(spectrum) + 1e-5)

# Custom transformation to apply FFT in the data pipeline
class FFTTransform:
    """Callable adapter that lets ``apply_fft`` be used as a pipeline step."""

    def __call__(self, image):
        """Return ``apply_fft(image)`` for the given image."""
        spectrum = apply_fft(image)
        return spectrum

In [None]:
# Define transforms for training and validation with FFT
def _fft_pipeline(*steps_before_fft):
    """Build resize -> optional extra steps -> FFT -> normalisation as one Compose."""
    # NOTE(review): the Normalize constants are the usual ImageNet pixel statistics but
    # are applied to the FFT output here — confirm that this is intended.
    return transforms.Compose([
        transforms.Resize((224, 224)),
        *steps_before_fft,
        FFTTransform(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


# Training pipeline: colour jitter is applied to the image before it is transformed
transformation_for_train = _fft_pipeline(
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
)

# Validation / test pipeline: same steps, without the colour jitter
transformation_for_valntest = _fft_pipeline()

In [None]:
# Load datasets
# Show which directory the training split is read from
print(train_path)

# One ImageFolder per split; only the training split uses the training pipeline
train_dataset = datasets.ImageFolder(
    root=train_path,
    transform=transformation_for_train,
)
val_dataset = datasets.ImageFolder(
    root=val_path,
    transform=transformation_for_valntest,
)
test_dataset = datasets.ImageFolder(
    root=test_path,
    transform=transformation_for_valntest,
)

# Batched loaders: the training split is shuffled, the evaluation splits are not
batch_size = 32
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Function to calculate recall and accuracy
def compute_metrics(outputs, labels):
    """Compute binary classification metrics from raw logits.

    Args:
        outputs: Tensor of logits; a sigmoid and a 0.5 threshold turn them
            into 0/1 predictions.
        labels: Tensor of 0/1 targets with the same number of elements as
            ``outputs``.

    Returns:
        tuple: ``(accuracy, recall, precision, f1)`` as Python floats. Any
        metric whose denominator is zero (including every metric for an
        empty batch) is reported as ``0.0``.
    """
    # Flatten both tensors so that (N, 1) logits and (N,) labels are compared
    # element-wise instead of silently broadcasting to an (N, N) grid.
    predicted = (torch.sigmoid(outputs) > 0.5).float().reshape(-1)
    labels = labels.reshape(-1)

    # Confusion-matrix counts
    tp = torch.sum((predicted == 1) & (labels == 1)).item()  # True positives
    fp = torch.sum((predicted == 1) & (labels == 0)).item()  # False positives
    fn = torch.sum((predicted == 0) & (labels == 1)).item()  # False negatives
    tn = torch.sum((predicted == 0) & (labels == 0)).item()  # True negatives

    # Accuracy (guarded so an empty batch does not divide by zero)
    total = tp + tn + fp + fn
    accuracy = (tp + tn) / total if total > 0 else 0.0

    # Precision
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0

    # Recall
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0

    # F1-Score
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0

    return accuracy, recall, precision, f1


# Freezing and unfreezing model layers
def freeze_everything_except_classifier(model):
    """Make ``model.classifier`` the only trainable part of ``model``.

    Every parameter's ``requires_grad`` flag is set explicitly: ``True`` for
    the parameters of ``model.classifier`` and ``False`` for all others, so
    the outcome does not depend on how the flags were set before the call.

    Args:
        model: A module exposing a ``classifier`` sub-module.
    """
    # Compare by identity: parameters are tensors, so identity is the
    # unambiguous way to test membership.
    classifier_param_ids = {id(param) for param in model.classifier.parameters()}

    for param in model.parameters():
        param.requires_grad = id(param) in classifier_param_ids

    print("Only training classifier")


def unfreeze_last_block(model):
    """Train only the last feature block and the classification head.

    Works for both ResNet-style naming (``layer4`` / ``fc``) and DenseNet-style
    naming (``denseblock4`` / ``norm5`` / ``classifier``). Every parameter's
    ``requires_grad`` flag is set: ``True`` for the parts listed, ``False``
    for everything else.

    Args:
        model: The module whose parameters are (un)frozen in place.
    """
    trainable_parts = {"layer4", "fc", "denseblock4", "norm5", "classifier"}

    for name, params in model.named_parameters():
        # Match whole dotted-path components, not substrings: a substring test
        # for "layer4" would also hit DenseNet's "denselayer4" inside every block.
        params.requires_grad = any(part in trainable_parts for part in name.split("."))

    print("Training last block and classifier")


def unfreeze_last_two_blocks(model):
    """Train only the last two feature blocks and the classification head.

    Works for both ResNet-style naming (``layer3`` / ``layer4`` / ``fc``) and
    DenseNet-style naming (``denseblock3`` / ``transition3`` / ``denseblock4``
    / ``norm5`` / ``classifier``). Every parameter's ``requires_grad`` flag is
    set: ``True`` for the parts listed, ``False`` for everything else.

    Args:
        model: The module whose parameters are (un)frozen in place.
    """
    trainable_parts = {
        "layer3", "layer4", "fc",
        "denseblock3", "transition3", "denseblock4", "norm5", "classifier",
    }

    for name, params in model.named_parameters():
        # Match whole dotted-path components, not substrings: a substring test for
        # "layer3"/"layer4" would also hit DenseNet's "denselayer3"/"denselayer4".
        params.requires_grad = any(part in trainable_parts for part in name.split("."))

    print("Training last 2 blocks and classifier")


def unfreeze_whole_model(model):
    """Mark every parameter of ``model`` as trainable, then announce it."""
    for tensor in model.parameters():
        tensor.requires_grad_(True)
    print("Whole model training")

In [None]:
# Pick the compute device first, then build the network and move it there
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained DenseNet-121 whose classifier is replaced by a single-output linear layer
densenetmodel = models.densenet121(pretrained=True)
no_features = densenetmodel.classifier.in_features
densenetmodel.classifier = nn.Linear(in_features=no_features, out_features=1)
densenetmodel = densenetmodel.to(device)

In [None]:
# Optimizer and loss function
# Binary cross-entropy computed directly on the model's raw logits
criterion = nn.BCEWithLogitsLoss()

# (freezing policy, learning rate) for each stage, in the order the stages run
_phase_plan = (
    (freeze_everything_except_classifier, 1e-3),
    (unfreeze_last_block, 1e-4),
    (unfreeze_last_two_blocks, 1e-5),
    (unfreeze_whole_model, 1e-6),
)

# Every stage lasts 10 epochs
phases = [
    {"epochs": 10, "unfreeze": policy, "lr": rate}
    for policy, rate in _phase_plan
]


In [None]:
# Training loop
# Progressive fine-tuning loop.
#
# Each phase applies its freezing policy to the SAME `densenetmodel`, so the weights
# learned in one phase are the starting point of the next. Re-run the model-creation
# cell before this one to restart from the pretrained weights.
#
# After every epoch a checkpoint is saved and the model is evaluated on the
# validation split (`val_loader`); the test split is not touched here.
checkpoint_dir = "../models"
os.makedirs(checkpoint_dir, exist_ok=True)  # torch.save does not create missing directories

with open("training_log.txt", "w") as log_file:
    for phase_idx, phase in enumerate(phases):
        # Apply this phase's freezing policy and optimise only what it left trainable.
        phase["unfreeze"](densenetmodel)
        trainable_params = [p for p in densenetmodel.parameters() if p.requires_grad]
        if not trainable_params:
            raise RuntimeError(
                f"Phase {phase_idx + 1} ({phase['unfreeze'].__name__}) left no trainable parameters"
            )
        optimiser = optim.Adam(trainable_params, lr=phase["lr"])
        log_file.write(f"Starting phase {phase_idx + 1}: {phase['unfreeze'].__name__} | Learning Rate: {phase['lr']}\n")

        for epoch in range(phase["epochs"]):
            # ---------------- Training ----------------
            densenetmodel.train()
            running_loss = 0.0
            correct = 0.0
            total = 0
            for inputs, labels in train_loader:
                inputs, labels = inputs.to(device), labels.float().to(device)

                optimiser.zero_grad()
                # squeeze(1) drops only the logit axis: (B, 1) -> (B,). A bare squeeze()
                # would also collapse a batch of one to 0-dim and break the loss.
                outputs = densenetmodel(inputs).squeeze(1)
                loss = criterion(outputs, labels)  # Compute loss
                loss.backward()
                optimiser.step()

                running_loss += loss.item()

                # Weight each batch's accuracy by its size so a short final batch
                # does not count as much as a full one.
                accuracy = compute_metrics(outputs.detach(), labels)[0]
                batch_len = labels.size(0)
                correct += accuracy * batch_len
                total += batch_len

            epoch_loss = running_loss / max(len(train_loader), 1)
            epoch_accuracy = correct / max(total, 1)
            log_file.write(f"Phase {phase_idx + 1}: {phase['unfreeze'].__name__}, Epoch [{epoch + 1}/{phase['epochs']}], "
                           f"Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.4f}\n")

            print(f"Phase {phase_idx + 1}: Epoch [{epoch + 1}/{phase['epochs']}], Loss: {epoch_loss:.4f}, "
                  f"Accuracy: {epoch_accuracy:.4f}")

            checkpoint_filename = f"{checkpoint_dir}/fftdensenet_phase{phase_idx+1}epoch_{epoch + 1}.pth"
            torch.save(densenetmodel.state_dict(), checkpoint_filename)
            print(f"Model saved as fftdensenet_phase{phase_idx+1}epoch_{epoch + 1}.pth")

            # ---------------- Validation ----------------
            densenetmodel.eval()
            total = 0
            running_loss = 0.0
            true_positive = 0
            false_positive = 0
            false_negative = 0
            true_negative = 0
            # Per-sample predictions/labels of the latest validation pass; not used
            # below, kept because later cells may read them — TODO confirm.
            all_preds = []
            all_labels = []

            with torch.no_grad():
                for inputs, labels in val_loader:
                    inputs, labels = inputs.to(device), labels.float().to(device)

                    outputs = densenetmodel(inputs).squeeze(1)  # (B, 1) -> (B,)

                    # Compute loss
                    loss = criterion(outputs, labels)
                    running_loss += loss.item()

                    # Convert logits to binary predictions
                    predicted = (torch.sigmoid(outputs) > 0.5).float()

                    total += labels.size(0)

                    # Update confusion-matrix counts
                    true_positive += ((predicted == 1) & (labels == 1)).sum().item()
                    false_positive += ((predicted == 1) & (labels == 0)).sum().item()
                    false_negative += ((predicted == 0) & (labels == 1)).sum().item()
                    true_negative += ((predicted == 0) & (labels == 0)).sum().item()

                    all_preds.extend(predicted.cpu().numpy())
                    all_labels.extend(labels.cpu().numpy())

            correct = true_positive + true_negative

            # Aggregate metrics over the whole validation split
            avg_loss = running_loss / max(len(val_loader), 1)
            avg_accuracy = correct / max(total, 1)

            # Recall = true_positive / (true_positive + false_negative)
            avg_recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0

            # Precision = true_positive / (true_positive + false_positive)
            avg_precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) > 0 else 0

            # F1 Score = 2 * (precision * recall) / (precision + recall)
            avg_f1 = 2 * (avg_precision * avg_recall) / (avg_precision + avg_recall) if (avg_precision + avg_recall) > 0 else 0

            # Write validation results to the log file
            log_file.write(f"Validation Loss: {avg_loss:.4f}\n")
            log_file.write(f"Validation Accuracy: {avg_accuracy:.4f}\n")
            log_file.write(f"Validation Precision: {avg_precision:.4f}\n")
            log_file.write(f"Validation Recall: {avg_recall:.4f}\n")
            log_file.write(f"Validation F1 Score: {avg_f1:.4f}\n")
            log_file.write("=" * 50 + "\n")  # Separator for clarity
            log_file.flush()  # make progress visible while long runs are still going