In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import os
import torchvision.transforms.functional as TF
import torch.nn as nn
import torchvision.models as models
import torch.optim as optim


## File Naming Convention

Each `.pt` file should follow this format:

```
<label>_<replicate>.pt
```

**Examples:**
- `150_1.pt` → Label: 150 (undetectable), Replicate: 1
- `500_2.pt` → Label: 500 (low), Replicate: 2
- `7000_3.pt` → Label: 7000 (medium), Replicate: 3
- `20000_5.pt` → Label: 20000 (high), Replicate: 5



## Class Definitions for the Semi-Quantitative Approach

The classes mirror clinical decision-making thresholds for viral load counts (assuming 1:1 sample preparation):
1. `undetectable` → Label values `< 200`
2. `low` → Label values `200 ≤ label ≤ 1000`
3. `medium` → Label values `1000 < label ≤ 10000`
4. `high` → Label values `> 10000`

## Dataset Folder Structure

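The dataset is organized into the following structure for training, validation, and testing. (The code below reads the splits from `Datasets/Split/SemiQuant/`, i.e. with one extra `SemiQuant/` level inside `Split/`.)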

```
Datasets/
│-- Split/
│   │-- Training/             # Training dataset (60% of total data)
│   │   ├── undetectable/      # Class 0 (e.g., files with labels < 200)
│   │   │   ├── 20_1.pt
│   │   │   ├── 40_3.pt
│   │   │   └── ...
│   │   ├── low/               # Class 1 (200 ≤ label ≤ 1000)
│   │   │   ├── 300_2.pt
│   │   │   ├── 600_4.pt
│   │   │   └── ...
│   │   ├── medium/            # Class 2 (1000 < label ≤ 10000)
│   │   │   ├── 2000_1.pt
│   │   │   ├── 7000_2.pt
│   │   │   └── ...
│   │   ├── high/              # Class 3 (label > 10000)
│   │   │   ├── 20000_2.pt
│   │   │   ├── 90000_2.pt
│   │   │   └── ...
│
│   │-- Validation/            # Validation dataset (20% of total data)
│   │   ├── undetectable/
│   │   │   ├── 20_2.pt
│   │   │   ├── 40_4.pt
│   │   │   └── ...
│   │   ├── low/
│   │   ├── medium/
│   │   ├── high/
│
│   │-- Testing/               # Testing dataset (20% of total data)
│   │   ├── undetectable/
│   │   │   ├── 30_1.pt
│   │   │   ├── 50_2.pt
│   │   │   └── ...
│   │   ├── low/
│   │   ├── medium/
│   │   ├── high/
│
│-- torch_tensors/              # Original .pt files before splitting
│   ├── 100_1.pt
│   ├── 200_3.pt
│   ├── 5000_2.pt
│   ├── 15000_4.pt
│   └── ...
```

## Folder Descriptions

- **`Training/`** – Used to train the model (60% of total data).
- **`Validation/`** – Used to validate the model during training (20% of total data).
- **`Testing/`** – Used to evaluate the model after training (20% of total data).
- **`torch_tensors/`** – Stores the original `.pt` files before they were split.



In [None]:
# Frames are resized from 500x500 to 224x224 (H x W) to speed up the initial tests.
class PTDataset(Dataset):
    """Dataset of ``.pt`` recordings stored in one sub-folder per class.

    Each file is loaded with ``torch.load(..., weights_only=True)`` and reduced to
    a stack of frames: the mean of the first ``baseline_frames`` frames followed
    by the frames at ``frame_indices``. The stack is then resized to
    ``target_size`` and, optionally, passed through ``transform``.

    Args:
        root_dir (str): Path to a split directory (e.g. the Training folder)
            containing the sub-folders named in ``self.classes``.
        target_size (tuple): Desired output size (height, width).
        transform (callable, optional): Optional transformation applied to the
            resized tensor.
        frame_indices (sequence of int): Frame positions (along dim 1 of the
            loaded tensor) to keep. Defaults to (69, 89, 109, 129, 149, 179).
        baseline_frames (int): Number of leading frames averaged into the first
            output frame. Defaults to 20.

    Raises:
        ValueError: If ``frame_indices`` is empty or ``baseline_frames`` < 1.
    """

    def __init__(self, root_dir, target_size=(224, 224), transform=None,
                 frame_indices=(69, 89, 109, 129, 149, 179), baseline_frames=20):
        if len(frame_indices) == 0:
            raise ValueError("frame_indices must contain at least one index")
        if baseline_frames < 1:
            raise ValueError("baseline_frames must be at least 1")

        self.root_dir = root_dir
        self.target_size = target_size
        self.transform = transform
        # TODO: plot the average sigmoidal curve to pick better frames.
        self.frame_indices = list(frame_indices)
        self.baseline_frames = baseline_frames
        # Minimum recording length so that every selected frame and the whole
        # baseline window exist.
        self.required_frames = max(max(self.frame_indices) + 1, baseline_frames)
        # Folder names; a file's label is the position of its folder in this list.
        self.classes = ['undetectable', 'low', 'medium', 'high']

        # Collect (path, label) pairs. Names are sorted because os.listdir order
        # is filesystem-dependent; sorting keeps dataset[i] stable across runs
        # and machines.
        self.file_list = []
        for class_idx, class_name in enumerate(self.classes):
            class_path = os.path.join(root_dir, class_name)
            for file in sorted(os.listdir(class_path)):
                if file.endswith('.pt'):
                    self.file_list.append((os.path.join(class_path, file), class_idx))

    def __len__(self):
        """Return the number of ``.pt`` files found under ``root_dir``."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load file ``idx`` and return ``(frames, label)``.

        ``frames`` has shape [C * (1 + len(frame_indices)), *target_size]. For
        single-channel recordings with the default settings, this is
        [7, 224, 224].

        Raises:
            ValueError: If the recording has fewer than ``required_frames`` frames.
        """
        file_path, label = self.file_list[idx]

        # NOTE(review): recordings are assumed to be laid out as [C, T, H, W]
        # (channels, frames, height, width) — confirm against the export script.
        tensor_data = torch.load(file_path, weights_only=True)

        max_frames = tensor_data.shape[1]
        if max_frames < self.required_frames:
            raise ValueError(f"Not enough frames in {file_path}, available: {max_frames}, required: {self.required_frames}")

        # torch.mean is only defined for floating-point and complex dtypes, so
        # integer recordings such as uint8 frames are promoted first.
        if not (tensor_data.is_floating_point() or tensor_data.is_complex()):
            tensor_data = tensor_data.float()

        baseline = tensor_data[:, :self.baseline_frames].mean(dim=1, keepdim=True)  # [C, 1, H, W]
        selected = tensor_data[:, self.frame_indices]  # [C, K, H, W]
        stacked = torch.cat((baseline, selected), dim=1)  # [C, 1 + K, H, W]

        # Fold the channel axis into the frame axis -> [C * (1 + K), H, W].
        # For C == 1 this equals squeeze(0). Unlike a bare squeeze(), it never
        # drops a size-1 H or W, and multi-channel recordings still yield a 3-D
        # tensor (the model's input_channels must then be C * (1 + K)).
        frames = stacked.reshape(-1, stacked.shape[-2], stacked.shape[-1])

        # TF.resize resizes the last two dimensions and treats the leading one as
        # channels, so all frames are resized in a single call. Interpolation is
        # per-channel, so this matches resizing each frame separately.
        resized_tensor = TF.resize(frames, self.target_size)

        if self.transform is not None:
            resized_tensor = self.transform(resized_tensor)

        return resized_tensor, label


In [None]:
# Root of the pre-split dataset; point this at your local copy.
DATA_ROOT = 'E:/Datasets/Split/SemiQuant'
TARGET_SIZE = (224, 224)
BATCH_SIZE = 32
SEED = 42

# Seed the global RNG (used for model initialization in later cells). The
# training loader below also gets its own seeded generator, so shuffling is
# reproducible under Restart & Run All.
torch.manual_seed(SEED)

train_dataset_path = f'{DATA_ROOT}/Training/'
val_dataset_path = f'{DATA_ROOT}/Validation/'
test_dataset_path = f'{DATA_ROOT}/Testing/'

train_dataset = PTDataset(root_dir=train_dataset_path, target_size=TARGET_SIZE)
val_dataset = PTDataset(root_dir=val_dataset_path, target_size=TARGET_SIZE)
test_dataset = PTDataset(root_dir=test_dataset_path, target_size=TARGET_SIZE)

# Fail early with a clear message instead of a division by zero during training.
for split_name, split in [('Training', train_dataset), ('Validation', val_dataset), ('Testing', test_dataset)]:
    assert len(split) > 0, f"No .pt files found in the {split_name} split under {DATA_ROOT}"

# num_workers > 0 errored out in this notebook. A likely (unconfirmed) cause is
# that worker processes started with the 'spawn' method (the default on Windows)
# cannot import a Dataset class defined inside the notebook. Moving PTDataset
# into a .py module usually resolves this.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0,
                          generator=torch.Generator().manual_seed(SEED))
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# Sanity check: shape and labels of one training batch.
inputs, labels = next(iter(train_loader))
print(f"Batch input shape: {inputs.shape}")  # Expected [batch_size, 7, 224, 224]
print(f"Batch labels: {labels}")


In [None]:
def get_resnet_model(num_classes=4, input_channels=7, pretrained=True):
    """Build a ResNet-18 classifier for ``input_channels``-channel inputs.

    The stem convolution is replaced so it accepts ``input_channels`` channels.
    The final fully connected layer is replaced to output ``num_classes`` logits.

    When ``pretrained`` is True, the backbone starts from torchvision's default
    ImageNet weights. The new stem is then initialized from the pretrained RGB
    kernels (averaged over RGB, repeated for every input channel and rescaled by
    3 / input_channels to keep activation magnitudes comparable) instead of being
    left randomly initialized, so the first layer also benefits from transfer
    learning.

    Args:
        num_classes (int): Number of output classes.
        input_channels (int): Number of channels in each input image.
        pretrained (bool): Load torchvision's default pretrained weights.

    Returns:
        torch.nn.Module: The adapted ResNet-18 model.
    """
    weights = models.ResNet18_Weights.DEFAULT if pretrained else None
    model = models.resnet18(weights=weights)

    # Keep the original stem kernels ([64, 3, 7, 7] in torchvision's ResNet-18)
    # before swapping the layer out.
    original_stem = model.conv1.weight.detach()
    model.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
    if pretrained:
        with torch.no_grad():
            mean_kernel = original_stem.mean(dim=1, keepdim=True)  # [64, 1, 7, 7]
            model.conv1.weight.copy_(mean_kernel.repeat(1, input_channels, 1, 1) * (3.0 / input_channels))

    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

# Use a GPU automatically when one is available. The original setup had no GPU,
# so there this falls back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = get_resnet_model(num_classes=4, input_channels=7).to(device)
# CrossEntropyLoss is the standard loss for single-label multi-class
# classification (4 classes here). It takes raw logits and integer class
# indices; a binary loss such as BCEWithLogitsLoss is only needed for
# two-class or multi-label setups.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10, device=None):
    """Train ``model`` for ``num_epochs`` epochs, validating after each epoch.

    For every epoch, prints the training loss/accuracy and the validation
    loss/accuracy returned by ``evaluate_model``.

    Args:
        model: Classifier whose outputs hold class scores along dim 1.
        train_loader: Iterable of (inputs, labels) batches used for training.
        val_loader: Iterable of (inputs, labels) batches passed to evaluate_model.
        criterion: Loss called as ``criterion(outputs, labels)``. It is assumed
            to average over the batch (the default ``reduction='mean'``).
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs (int): Number of passes over ``train_loader``.
        device (torch.device, optional): Device to move each batch to. Defaults
            to the device of the model's first parameter (CPU if it has none).

    Raises:
        ValueError: If ``train_loader`` yields no samples in an epoch.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    for epoch in range(num_epochs):
        # evaluate_model() switches the model to eval mode at the end of every
        # epoch, so training mode must be re-enabled here each epoch. Otherwise
        # BatchNorm/Dropout would run in inference mode from epoch 2 onward.
        model.train()
        running_loss = 0.0
        correct, total = 0, 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)  # Forward pass
            loss = criterion(outputs, labels)  # Compute loss
            loss.backward()  # Backpropagation
            optimizer.step()  # Update weights

            batch_size = labels.size(0)
            # Weight the batch-mean loss by batch size so a smaller final batch
            # does not skew the epoch average.
            running_loss += loss.item() * batch_size
            total += batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()

        if total == 0:
            raise ValueError("train_loader yielded no samples")

        train_loss = running_loss / total
        train_acc = 100 * correct / total

        # Perform validation after each epoch
        val_loss, val_acc = evaluate_model(model, val_loader, criterion)

        print(f"Epoch {epoch+1}/{num_epochs}, "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

    print("Done")

def evaluate_model(model, loader, criterion, device=None):
    """Compute the average loss and accuracy of ``model`` over ``loader``.

    Puts the model in eval mode (and leaves it there) and disables gradient
    tracking while evaluating.

    Args:
        model: Classifier whose outputs hold class scores along dim 1.
        loader: Iterable of (inputs, labels) batches.
        criterion: Loss called as ``criterion(outputs, labels)``. It is assumed
            to average over the batch (the default ``reduction='mean'``).
        device (torch.device, optional): Device to move each batch to. Defaults
            to the device of the model's first parameter (CPU if it has none).

    Returns:
        tuple: ``(avg_loss, accuracy)``. ``avg_loss`` is the per-sample mean
        loss (batch losses weighted by batch size); ``accuracy`` is a percentage.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    running_loss = 0.0
    correct, total = 0, 0

    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # Weight by batch size so a smaller final batch does not skew the mean.
            running_loss += loss.item() * batch_size
            total += batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()

    if total == 0:
        raise ValueError("loader yielded no samples")

    avg_loss = running_loss / total
    accuracy = 100 * correct / total
    return avg_loss, accuracy

In [None]:
# Train for a fixed number of epochs; train_model prints per-epoch metrics.
NUM_EPOCHS = 10
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=NUM_EPOCHS,
)


In [None]:
test_loss, test_acc = evaluate_model(model, test_loader, criterion)
print(f"Final Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%")

MODEL_PATH = 'resnet_model.pth'
torch.save(model.state_dict(), MODEL_PATH)
print(f"Model saved as {MODEL_PATH}")

# Round-trip check: reload the saved weights.
# - weights_only=True restricts unpickling to tensors and plain containers. This
#   is safer and consistent with how PTDataset loads its .pt files.
# - map_location keeps the tensors on the current device.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))
model.eval()
print("Model loaded for inference.")

In [None]:
def predict(model, input_tensor, device=None):
    """Return the predicted class index for a single, unbatched input.

    Args:
        model: Classifier whose outputs hold class scores along dim 1.
        input_tensor: One sample without a batch dimension (for this notebook's
            dataset, presumably [7, 224, 224]). A batch dimension is added here.
        device (torch.device, optional): Device to run on. Defaults to the
            device of the model's first parameter (CPU if it has none).

    Returns:
        int: Index of the highest-scoring class.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    # Inference needs no autograd graph; skipping it saves memory and time.
    with torch.no_grad():
        output = model(input_tensor.unsqueeze(0).to(device))  # Add batch dimension
    return output.argmax(dim=1).item()

# Example inference on the first test sample, shown alongside its true label so
# the prediction can be judged at a glance.
sample_input, sample_label = test_dataset[0]
prediction = predict(model, sample_input)
print(f"Predicted class: {prediction} ({test_dataset.classes[prediction]}), "
      f"true class: {sample_label} ({test_dataset.classes[sample_label]})")