In [None]:
from data_loader_co_reg import *
from model import *
from data_extraction_nii import *
import torch
from torch.utils.data import DataLoader, random_split
from torch import nn, optim

In [None]:

target_sequences, moving_sequences, number_of_frames = extract_data(patient_n=2)
# extract_data and AugmentationDataLoader are presumably provided by the wildcard
# imports in the first cell; neither is defined in this notebook.
dataset = AugmentationDataLoader(target_sequences, moving_sequences)

# Split the dataset into train and validation sets (80% / 20%).
# The seeded generator makes the split identical across kernel restarts, so
# validation numbers from different runs are comparable.
TRAIN_FRACTION = 0.8
SPLIT_SEED = 42
train_size = int(TRAIN_FRACTION * len(dataset))
test_size = len(dataset) - train_size
print("train size, val size", train_size, test_size)
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=2, shuffle=False)

# Sanity check: show the shapes of one training batch (skipped if the split is empty).
batch = next(iter(train_dataloader), None)
if batch is not None:
    combined_augmented_seq, theta, dx, dy = batch
    print(combined_augmented_seq.shape, theta.shape, dx.shape, dy.shape)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import wandb


# Open a Weights & Biases run in project "diff"; the training loop below sends
# its per-batch losses to it through wandb.log.
wandb_run = wandb.init(project="diff")

# Training function
def _noisy_target_losses(model, batch, criterion, device):
    """Run one noised forward pass on a batch and return the theta/dx/dy losses.

    The batch is unpacked as ``(combined_augmented_seq, theta, dx, dy)``. Each
    target is perturbed with Gaussian noise scaled by a per-sample level ``t``
    in (0, 1]. The model receives the sequence, the expanded ``t`` and the three
    noisy targets, and its output is compared part-by-part against the clean
    concatenated targets.

    Returns:
        List of three loss tensors, ordered theta, dx, dy.

    Raises:
        ValueError: if the model output's last dimension does not equal the
            combined width of theta, dx and dy.
    """
    combined_augmented_seq, theta, dx, dy = batch
    # .to(device).float() keeps the tensor on `device`; .type(torch.FloatTensor)
    # would silently move it back to the CPU and break CUDA training.
    combined_augmented_seq = combined_augmented_seq.to(device).float()
    theta = theta.to(device)
    dx = dx.to(device)
    dy = dy.to(device)

    batch_size = combined_augmented_seq.size(0)
    # 1 - U[0, 1) lies in (0, 1], so every sample receives a non-zero noise level.
    t = 1 - torch.rand(batch_size, device=device)
    # One noise level per sample, repeated across the target dimension.
    # NOTE(review): dx and dy are assumed to have the same last dimension as theta.
    t_expanded = t.view(batch_size, 1).expand(batch_size, theta.shape[-1])

    noisy_theta = theta + t_expanded * torch.randn_like(theta)
    noisy_dx = dx + t_expanded * torch.randn_like(dx)
    noisy_dy = dy + t_expanded * torch.randn_like(dy)

    predictions = model(combined_augmented_seq, t_expanded, noisy_theta, noisy_dx, noisy_dy)
    gt = torch.cat((theta, dx, dy), dim=-1)

    # Split by the actual target widths instead of a hard-coded 100 per part.
    split_sizes = [theta.shape[-1], dx.shape[-1], dy.shape[-1]]
    if predictions.shape[-1] != sum(split_sizes):
        raise ValueError(
            f"model output has last dimension {predictions.shape[-1]}, "
            f"expected {sum(split_sizes)} (theta + dx + dy)"
        )
    pred_parts = torch.split(predictions, split_sizes, dim=-1)
    gt_parts = torch.split(gt, split_sizes, dim=-1)
    return [criterion(pred, target) for pred, target in zip(pred_parts, gt_parts)]


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10,
                log_fn=None, verbose=False):
    """Train ``model`` to recover noised (theta, dx, dy) targets, validating every epoch.

    Args:
        model: called as ``model(seq, t, noisy_theta, noisy_dx, noisy_dy)``. Its
            output is split along the last dimension into theta, dx and dy
            predictions. It is expected to already live on the device chosen
            here (CUDA when available, else CPU).
        train_loader: iterable of ``(combined_augmented_seq, theta, dx, dy)`` batches.
        val_loader: iterable with the same batch structure, used for validation.
        criterion: loss applied separately to each of the three parts; the
            total loss is their mean.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.
        log_fn: callable receiving one metrics dict per batch; defaults to ``wandb.log``.
        verbose: if True, print the per-part losses of every batch.

    Returns:
        Dict with per-epoch mean batch losses under ``"train_loss"`` and
        ``"val_loss"``. An epoch whose loader yielded no batches records NaN.
    """
    if log_fn is None:
        log_fn = wandb.log
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    history = {"train_loss": [], "val_loss": []}

    for epoch in range(num_epochs):
        model.train()
        train_sum, train_batches = 0.0, 0
        for batch in train_loader:
            optimizer.zero_grad()
            losses = _noisy_target_losses(model, batch, criterion, device)
            total_loss = sum(losses) / len(losses)
            total_loss.backward()
            optimizer.step()

            train_sum += total_loss.item()
            train_batches += 1
            if verbose:
                for i, loss in enumerate(losses):
                    print(f"Loss {i+1}: {loss.item()}")
                print(f"Total Loss: {total_loss.item()}")
            log_fn({"Training Loss": total_loss.item(), "loss theta": losses[0].item(),
                    "loss dx": losses[1].item(), "loss dy": losses[2].item()})

        # Validation step: use the loader passed in, not a notebook-level global.
        model.eval()
        val_sum, val_batches = 0.0, 0
        with torch.no_grad():
            for batch in val_loader:
                val_parts = _noisy_target_losses(model, batch, criterion, device)
                val_total = sum(val_parts) / len(val_parts)

                val_sum += val_total.item()
                val_batches += 1
                if verbose:
                    for i, loss in enumerate(val_parts):
                        print(f"Loss {i+1}: {loss.item()}")
                    print(f"Total val Loss: {val_total.item()}")
                # Distinct keys so validation metrics are not mixed into the training charts.
                log_fn({"Validation Loss": val_total.item(), "val loss theta": val_parts[0].item(),
                        "val loss dx": val_parts[1].item(), "val loss dy": val_parts[2].item()})

        epoch_train = train_sum / train_batches if train_batches else float("nan")
        epoch_val = val_sum / val_batches if val_batches else float("nan")
        history["train_loss"].append(epoch_train)
        history["val_loss"].append(epoch_val)
        print(f"Epoch {epoch + 1}/{num_epochs}: train loss {epoch_train:.6f}, val loss {epoch_val:.6f}")

    return history


# Initialize the model, loss function, and optimizer.
SEED = 42             # makes weight init, loader shuffling and training noise repeatable
LEARNING_RATE = 1e-3
NUM_EPOCHS = 10

torch.manual_seed(SEED)
# NOTE(review): an earlier comment claimed an embedding size of 512; CombinedModel()
# is called without arguments, so that presumably comes from its default — confirm in model.py.
model = CombinedModel()
criterion = nn.MSELoss()

# Move the model to the appropriate device (GPU or CPU) before building the
# optimizer; train_model selects its batch device with the same rule.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

# Train the model with validation.
history = train_model(model, train_dataloader, test_dataloader, criterion, optimizer, num_epochs=NUM_EPOCHS)