In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import os
from collections import Counter
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

## File Naming Convention

Each `.pt` file should follow this format:

```
<label>_<replicate>.pt
```

**Examples:**
- `150_1.pt` → Label: 150 (undetectable), Replicate: 1
- `500_2.pt` → Label: 500 (low), Replicate: 2
- `7000_3.pt` → Label: 7000 (medium), Replicate: 3
- `20000_5.pt` → Label: 20000 (high), Replicate: 5



## Class Definitions for Semi-Quantitative approach 

The classes mirror clinical decision-making thresholds based on viral load counts (assuming 1:1 sample prep):

1. `undetectable` → Label values `< 200`
2. `low` → Label values `200 ≤ label ≤ 1000`
3. `medium` → Label values `1000 < label ≤ 10000`
4. `high` → Label values `> 10000`

# Dataset Folder Structure

Dataset is organized into the following structure to ensure proper training, validation, and testing:

```
Datasets/
│-- SemiQuant/
│   │-- Training/             # Training dataset (60% of total data)
│   │   ├── undetectable/      # Class 0 (e.g., files with labels < 200)
│   │   │   ├── 20_1.pt
│   │   │   ├── 40_3.pt
│   │   │   └── ...
│   │   ├── low/               # Class 1 (200 ≤ label ≤ 1000)
│   │   │   ├── 300_2.pt
│   │   │   ├── 600_4.pt
│   │   │   └── ...
│   │   ├── medium/            # Class 2 (1000 < label ≤ 10000)
│   │   │   ├── 2000_1.pt
│   │   │   ├── 7000_2.pt
│   │   │   └── ...
│   │   ├── high/              # Class 3 (label > 10000)
│   │   │   ├── 20000_2.pt
│   │   │   ├── 90000_2.pt
│   │   │   └── ...
│
│   │-- Validation/            # Validation dataset (20% of total data)
│   │   ├── undetectable/
│   │   │   ├── 20_2.pt
│   │   │   ├── 40_4.pt
│   │   │   └── ...
│   │   ├── low/
│   │   ├── medium/
│   │   ├── high/
│
│   │-- Testing/               # Testing dataset (20% of total data)
│   │   ├── undetectable/
│   │   │   ├── 30_1.pt
│   │   │   ├── 50_2.pt
│   │   │   └── ...
│   │   ├── low/
│   │   ├── medium/
│   │   ├── high/
│
│-- torch_tensors/              # Original .pt files before splitting
│   │   ├── 100_1.pt
│   │   ├── 200_3.pt
│   │   ├── 5000_2.pt
│   │   ├── 15000_4.pt
│   │   └── ...
```

## Folder Descriptions

- **`Training/`** – Used to train the model (60% of total data).
- **`Validation/`** – Used to validate the model during training (20% of total data).
- **`Testing/`** – Used to evaluate the model after training (20% of total data).
- **`torch_tensors/`** – Stores the original `.pt` files before they were split.



In [None]:
class PTDataset(Dataset):
    """Dataset that eagerly loads `.pt` recordings and reduces each to a 7-frame stack.

    Files are discovered under ``root_dir/<class_name>/*.pt`` for every name in
    ``self.classes``; a sample's label is the index of its folder name in that
    list. All samples are processed once in ``__init__`` and kept in CPU
    memory, so ``__getitem__`` is a plain list lookup.

    Each stored sample is built from a recording (expected layout
    ``[C, T, H, W]``) as: the mean of the first ``BASELINE_FRAMES`` frames,
    followed by the frames at ``FRAME_INDICES``, resized to ``target_size``.
    """

    # Frames picked from each recording, in addition to the averaged baseline.
    FRAME_INDICES = (69, 89, 109, 129, 149, 179)
    # Number of leading frames averaged into the first output frame.
    BASELINE_FRAMES = 20

    def __init__(self, root_dir, target_size=(500, 500), transform=None):
        """
        Args:
            root_dir (str): Path to the dataset directory (e.g., Training folder).
            target_size (tuple): Desired output size (height, width).
            transform (callable, optional): Optional transformations (on CPU),
                applied to the resized tensor before the leading singleton
                dimension is dropped.

        Raises:
            ValueError: If a recording has too few frames to provide the
                highest index in ``FRAME_INDICES``.
        """
        self.root_dir = root_dir
        self.target_size = target_size
        self.transform = transform
        self.classes = ['undetectable', 'low', 'medium', 'high']

        # Collect (path, class_index) pairs. Directory listings are sorted so
        # that sample indices are reproducible across runs and platforms
        # (os.listdir returns entries in arbitrary order).
        self.file_list = []
        for class_index, class_name in enumerate(self.classes):
            class_path = os.path.join(root_dir, class_name)
            if not os.path.isdir(class_path):
                continue  # Skip classes that have no folder in this split
            for file_name in sorted(os.listdir(class_path)):
                if file_name.endswith('.pt'):
                    self.file_list.append((os.path.join(class_path, file_name), class_index))

        # Pre-load everything into memory (CPU)
        self.data_list = [
            (self._load_sample(file_path), label)
            for file_path, label in self.file_list
        ]

    def _load_sample(self, file_path):
        """Load one recording from disk and reduce it to the resized frame stack."""
        # NOTE: torch.load unpickles the file; only point this at trusted data.
        tensor_data = torch.load(file_path, map_location='cpu')  # expected [C, T, H, W]

        # Ensure enough frames for the highest selected index
        max_frames = tensor_data.shape[1]
        required_frames = max(self.FRAME_INDICES) + 1
        if max_frames < required_frames:
            raise ValueError(
                f"Not enough frames in {file_path}, available: {max_frames}, required: {required_frames}"
            )

        # Baseline: average of the leading frames -> [C, 1, H, W]
        baseline = torch.mean(tensor_data[:, :self.BASELINE_FRAMES, :, :], dim=1, keepdim=True)
        # Selected frames -> [C, 6, H, W]
        selected = tensor_data[:, list(self.FRAME_INDICES), :, :]

        # Concatenate to form a 7-frame tensor -> [C, 7, H, W]
        stacked = torch.cat((baseline, selected), dim=1)

        # Normalise to a 4-D tensor for F.interpolate: a single-channel stack is
        # squeezed to [7, H, W] and then given back a leading batch dimension.
        if stacked.shape[0] == 1:
            stacked = stacked.squeeze(0)
        if stacked.dim() == 3:
            stacked = stacked.unsqueeze(0)  # -> [1, 7, H, W]

        # Resize on CPU
        resized = F.interpolate(
            stacked,
            size=self.target_size,
            mode='bilinear',
            align_corners=False
        )

        # Optional transform
        if self.transform:
            resized = self.transform(resized)

        # Model expects input_channels=7: drop the singleton lead -> [7, H, W]
        if resized.shape[0] == 1:
            resized = resized.squeeze(0)

        return resized

    def __len__(self):
        """Return the number of pre-loaded samples."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return the ``(tensor, class_index)`` pair stored at ``idx``."""
        return self.data_list[idx]

In [None]:
def get_resnet_model(num_classes=4, input_channels=7):
    """
    Create a ResNet34 (default pretrained weights) adapted to this task.

    The stem convolution is swapped for a freshly initialised one that accepts
    `input_channels` channels, and the final fully connected layer is swapped
    for one that emits `num_classes` logits. Both replaced layers start from
    their default random initialisation; the remaining layers keep the loaded
    weights.

    Args:
        num_classes (int): Number of output logits.
        input_channels (int): Number of channels in the input tensor.

    Returns:
        The modified ResNet34 module.
    """
    backbone = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)

    # New stem: same geometry as the stock ResNet stem, different channel count
    backbone.conv1 = nn.Conv2d(
        in_channels=input_channels,
        out_channels=64,
        kernel_size=7,
        stride=2,
        padding=3,
        bias=False,
    )

    # New classification head sized for our label set
    feature_dim = backbone.fc.in_features
    backbone.fc = nn.Linear(feature_dim, num_classes)

    return backbone

def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device, num_epochs=10):
    """
    Basic training routine using CrossEntropyLoss
    for single-label, multi-class classification.

    Runs `num_epochs` passes over `train_loader`, evaluates on `val_loader`
    after every epoch, steps `scheduler` once per epoch and prints the
    per-epoch train/validation loss and accuracy.

    Args:
        model: Module to train; must already live on `device`.
        train_loader: Iterable of `(inputs, labels)` training batches.
        val_loader: Iterable of `(inputs, labels)` validation batches.
        criterion: Loss called as `criterion(outputs, labels)`.
        optimizer: Optimizer over the model parameters.
        scheduler: LR scheduler stepped once at the end of each epoch.
        device (torch.device): Device the batches are moved to.
        num_epochs (int): Number of epochs to run.
    """
    # Mixed precision is only used on CUDA. The scaler and autocast share one
    # switch so they cannot disagree (a default GradScaler() targets CUDA and
    # is enabled regardless of the device actually in use).
    use_amp = device.type == 'cuda'
    scaler = torch.amp.GradScaler(device.type, enabled=use_amp)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct, total = 0, 0

        for inputs, labels in train_loader:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad()

            with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            # With the scaler disabled these calls reduce to a plain
            # backward() / optimizer.step().
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Weight the batch loss by batch size so the epoch value is a
            # per-sample average even when the last batch is smaller.
            running_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        epoch_loss = running_loss / total
        epoch_acc = 100.0 * correct / total

        val_loss, val_acc = evaluate_model(model, val_loader, criterion, device)

        scheduler.step()

        print(f"Epoch [{epoch+1}/{num_epochs}] "
              f"Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.2f}%, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

    print("Training complete.")


def evaluate_model(model, loader, criterion, device):
    """
    Compute the average loss and accuracy of `model` over `loader`.

    The model is switched to eval mode and gradients are disabled. Autocast is
    enabled only on CUDA, using the same device-generic `torch.amp.autocast`
    API as the training loop (the CUDA-specific `torch.cuda.amp.autocast`
    spelling is deprecated).

    Args:
        model: Module to evaluate; must already live on `device`.
        loader: Iterable of `(inputs, labels)` batches.
        criterion: Loss called as `criterion(outputs, labels)`.
        device (torch.device): Device the batches are moved to.

    Returns:
        tuple: `(avg_loss, avg_acc)` where `avg_loss` is the per-sample mean
        loss and `avg_acc` is the accuracy in percent.
    """
    model.eval()
    running_loss = 0.0
    correct, total = 0, 0
    use_amp = device.type == 'cuda'

    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            # Weight by batch size so a smaller final batch is averaged fairly
            running_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    avg_loss = running_loss / total
    avg_acc = 100.0 * correct / total
    return avg_loss, avg_acc

def plot_confusion_matrix(model, loader, device, class_names):
    """
    Generates and displays a confusion matrix for the model on the given loader.

    Args:
        model: Trained module; must already live on `device`.
        loader: Iterable of `(inputs, labels)` batches.
        device (torch.device): Device the batches are moved to.
        class_names (list[str]): Display name for each class index, in index
            order. The matrix is built for class indices
            `0 .. len(class_names) - 1`; samples whose label or prediction
            falls outside that range are not counted.
    """
    model.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    # Pin the matrix to one row/column per class name. Without `labels=`,
    # sklearn only includes classes that actually occur, so a class missing
    # from both labels and predictions would shrink the matrix and no longer
    # match `display_labels`.
    label_ids = list(range(len(class_names)))
    conf_matrix = confusion_matrix(all_labels, all_preds, labels=label_ids)
    disp = ConfusionMatrixDisplay(confusion_matrix=conf_matrix, display_labels=class_names)
    disp.plot(cmap=plt.cm.Blues)
    plt.title("Confusion Matrix")
    plt.show()

In [None]:
def main(data_root='M:/Datasets/int_split', num_epochs=13, batch_size=32,
         learning_rate=0.001, save_path='resnet_model_gamma_13.pth',
         target_size=(500, 500)):
    """
    Train, evaluate and save the 4-class ResNet34 on the pre-split dataset.

    Args:
        data_root (str): Directory containing the `Training`, `Validation`
            and `Testing` split folders.
        num_epochs (int): Number of training epochs.
        batch_size (int): Batch size used by all three loaders.
        learning_rate (float): Adam learning rate.
        save_path (str): Where the trained `state_dict` is written.
        target_size (tuple): (height, width) every sample is resized to.

    Returns:
        dict: The objects produced by the run, so notebook cells that follow
        can reuse them: `model`, `device`, `train_dataset`, `val_dataset`,
        `test_dataset`, `test_loss`, `test_acc`.

    Raises:
        ValueError: If no training samples are found under `data_root`.
    """
    # 1. Check device for cuda or cpu
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # 2. Define dataset paths
    train_dataset_path = os.path.join(data_root, 'Training')
    val_dataset_path = os.path.join(data_root, 'Validation')
    test_dataset_path = os.path.join(data_root, 'Testing')

    # 3. Load datasets
    train_dataset = PTDataset(root_dir=train_dataset_path, target_size=target_size)
    if len(train_dataset) == 0:
        # PTDataset silently skips missing class folders, so a wrong path
        # would otherwise only surface as an obscure sampler error below.
        raise ValueError(f"No training samples found under {train_dataset_path!r}; check the dataset path.")
    val_dataset = PTDataset(root_dir=val_dataset_path, target_size=target_size)
    test_dataset = PTDataset(root_dir=test_dataset_path, target_size=target_size)

    # 4. Create WeightedRandomSampler for training:
    #    each sample is weighted by the inverse frequency of its class so
    #    that classes are drawn roughly equally often.
    train_labels = [label for _, label in train_dataset.data_list]
    class_counts = Counter(train_labels)
    weights = [1.0 / class_counts[label] for label in train_labels]

    train_sampler = WeightedRandomSampler(
        weights=weights,
        num_samples=len(weights),
        replacement=True
    )

    # 5. Create DataLoaders
    use_pin_memory = device.type == 'cuda'

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        shuffle=False,  # ordering is delegated to the sampler
        num_workers=0,
        pin_memory=use_pin_memory
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=use_pin_memory
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=use_pin_memory
    )

    # 6. Build model, move to device
    model = get_resnet_model(num_classes=4, input_channels=7)
    model.to(device)

    # 7. Define loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

    # 8. Train
    train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device, num_epochs=num_epochs)

    # 9. Evaluate on test set
    test_loss, test_acc = evaluate_model(model, test_loader, criterion, device)
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%")

    # 10. Save model (report the path that was actually written)
    torch.save(model.state_dict(), save_path)
    print(f"Model saved as {save_path}")

    return {
        "model": model,
        "device": device,
        "train_dataset": train_dataset,
        "val_dataset": val_dataset,
        "test_dataset": test_dataset,
        "test_loss": test_loss,
        "test_acc": test_acc,
    }


if __name__ == "__main__":
    # Publish the run's objects at notebook scope: the cells below use
    # `model`, `device`, `test_dataset` and `train_dataset`, which would be
    # undefined on a fresh kernel if they stayed local to main().
    run_artifacts = main()
    model = run_artifacts["model"]
    device = run_artifacts["device"]
    train_dataset = run_artifacts["train_dataset"]
    val_dataset = run_artifacts["val_dataset"]
    test_dataset = run_artifacts["test_dataset"]

#     # If you want to plot the confusion matrix after everything is done:
#     # (Just call the function from anywhere outside `main()`.)
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     class_names = ['Und.', 'Low', 'Med', 'High']

#     # Re-initialize your model, load state dict if needed, etc.
#     model = get_resnet_model(num_classes=4, input_channels=7)
#     model.load_state_dict(torch.load('resnet_model_gamma2.pth', map_location=device))
#     model.to(device)

    # # Re-create test_loader (or pass it around as a global var) # ran out of ram for this to be in the funct
    # test_dataset_path = 'M:/Datasets/int_split/Testing/'
    # test_dataset = PTDataset(root_dir=test_dataset_path, target_size=(500, 500))
    # test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

    # Plot confusion matrix
    # plot_confusion_matrix(model, test_loader, device, class_names)

In [None]:
# Display names for class indices 0..3, in the same order as PTDataset.classes.
# Defined here because this cell is the first one that uses them.
class_names = ['Und.', 'Low', 'Med', 'High']

# NOTE(review): relies on `test_dataset`, `model` and `device` already being
# present in the notebook namespace.
test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=0
)

# Plot confusion matrix
plot_confusion_matrix(model, test_loader, device, class_names)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Inline confusion-matrix evaluation on the test set.
# NOTE(review): relies on `model` and `test_loader` already being present in
# the notebook namespace.

# Display names for class indices 0..3; defined before use so the matrix can
# be pinned to exactly these classes.
class_names = ['Und.', 'Low', 'Med', 'High']

# Ensure the model is in evaluation mode
model.eval()

# Move model to appropriate device (CPU/GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Collect all true labels and predictions
all_preds = []
all_labels = []

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        outputs = model(inputs)

        # Predicted class = index of the largest logit
        _, preds = torch.max(outputs, 1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Pass `labels=` so the matrix always has one row/column per class name, even
# when a class is absent from both the labels and the predictions; otherwise
# its size would not match `display_labels`.
conf_matrix = confusion_matrix(all_labels, all_preds, labels=list(range(len(class_names))))

disp = ConfusionMatrixDisplay(conf_matrix, display_labels=class_names)
disp.plot(cmap=plt.cm.Blues)
plt.title("Confusion Matrix")
plt.show()


In [None]:
# Sanity check on one training sample: print its class index and show the
# last frame of its stack.
# NOTE(review): relies on `train_dataset` being present in the notebook namespace.
sample_frames, sample_label = train_dataset[7]
print(sample_label)
plt.imshow(sample_frames[-1, :, :])