In [None]:
from data_loader_co_reg import *
from model import *
from data_extraction_nii import *
import torch
from torch.utils.data import DataLoader, random_split
from torch import nn, optim
from data_loader import *

In [None]:

# Load one patient's target/moving sequences and wrap them in the project's
# augmentation dataset (AugmentationDataLoader comes from the star imports above).
target_sequences, moving_sequences, number_of_frames = extract_data(patient_n=2)
dataset = AugmentationDataLoader(target_sequences, moving_sequences)

# Split the dataset into train and test sets (80% train, 20% test).
# A seeded generator keeps the split identical across kernel restarts;
# without it random_split draws from torch's global RNG.
SPLIT_SEED = 42
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
print("train size, val size", train_size, test_size)
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=2, shuffle=False)

# Sanity-check the shapes of a single training batch (does nothing if the
# training loader is empty).
batch = next(iter(train_dataloader), None)
if batch is not None:
    combined_augmented_seq, theta, dx, dy = batch
    print(combined_augmented_seq.shape, theta.shape, dx.shape, dy.shape)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import wandb


# Open a Weights & Biases run under the "diff" project; train_model's
# wandb.log calls below report their losses into this run.
wandb_run = wandb.init(project="diff")

# Training function
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for epoch in range(num_epochs):
        model.train()  # Ensure the model is in training mode
        running_loss = 0.0
        
        for i,batch in enumerate(train_loader):
            # Example batch components
            combined_augmented_seq, theta, dx, dy = batch
            
            # Move each part to the device separately
            combined_augmented_seq = combined_augmented_seq.type(torch.FloatTensor)
            combined_augmented_seq = combined_augmented_seq.to(device)
            
            theta = theta.to(device)
            dx = dx.to(device)
            dy = dy.to(device)
           
            # Generate t-values and theta_noise for the current batch
            batch_size, _, num_frames, _, _ = combined_augmented_seq.size()
            t = (1-torch.rand(batch_size)).to(device)  # t-value sampled uniformly between 0 and 1

            t_expanded = t.view(batch_size, 1).expand(batch_size, theta.shape[-1])  # Expand t to match theta_noise shape
     

            optimizer.zero_grad()

            if epoch ==0:
                t = torch.ones(batch_size, 1).to(device)
                t_expanded = t.view(batch_size, 1).expand(batch_size, theta.shape[-1]) 
                noisy_theta_noise = torch.randn_like(theta).to(device)
                noisy_theta_x_noise = torch.randn_like(dx).to(device)
                noisy_theta_y_noise = torch.randn_like(dy).to(device)
                predictions = model(combined_augmented_seq, t_expanded, noisy_theta_noise, noisy_theta_x_noise, noisy_theta_y_noise)
                #gt = torch.cat((noisy_theta_noise, noisy_theta_x_noise, noisy_theta_y_noise), dim=-1)
                gt = torch.cat((theta, dx, dy), dim=-1)
                loss = criterion(predictions, gt_split)
            else:

                # Add Gaussian noise to theta_noise using the t-value
                predictions_split = torch.split(predictions, 100, dim=1)
                gt_split = torch.split(gt, 100, dim=1)
                noisy_theta_noise = noisy_theta_noise + t_expanded * (predictions[0] - gt[0])
                noisy_theta_x_noise = noisy_theta_x_noise + t_expanded * (predictions[1] - gt[1])
                noisy_theta_y_noise = noisy_theta_y_noise + t_expanded * (predictions[2] - gt[2])
                back_transformer = BackTransformation()
                combined_augmented_seq = back_transformer.apply_back_transform(combined_augmented_seq, predictions_split[0], predictions_split[1], predictions_split[2])
                predictions = model(combined_augmented_seq, t_expanded, noisy_theta_noise, noisy_theta_x_noise, noisy_theta_y_noise)
                #gt = torch.cat((noisy_theta_noise, noisy_theta_x_noise, noisy_theta_y_noise), dim=-1)
                gt = torch.cat((theta, dx, dy), dim=-1)
                loss = criterion(predictions, gt)


            loss.backward()
            optimizer.step()
            running_loss += loss.item()
     
            print(f"Total Loss: {loss.item()}")
            wandb.log({"Training Loss": loss.item()})

                # Validation step
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for val_batch in val_loader:
                combined_augmented_seq, theta, dx, dy = val_batch
                combined_augmented_seq = combined_augmented_seq.to(device).float()
                theta = theta.to(device)
                dx = dx.to(device)
                dy = dy.to(device)
                batch_size, _, num_frames, _, _ = combined_augmented_seq.size()
                t = (1 - torch.rand(batch_size)).to(combined_augmented_seq.device)
                t_expanded = t.view(batch_size, 1).expand(batch_size, theta.shape[-1])
                noisy_theta_noise = t_expanded * torch.randn_like(theta)
                noisy_theta_x_noise = t_expanded * torch.randn_like(dx)
                noisy_theta_y_noise = t_expanded * torch.randn_like(dy)
                
                val_predictions = model(combined_augmented_seq, t_expanded, noisy_theta_noise, noisy_theta_x_noise, noisy_theta_y_noise)
                gt = torch.cat((theta, dx, dy), dim=-1)
                val_loss_total = criterion(val_predictions, gt)
                val_loss += val_loss_total.item()
        
                print(f"Total val Loss: {val_loss_total.item()}")
                wandb.log({"Validation Loss": val_loss_total.item()})


# Initialize the model, loss function, and optimizer
model = CombinedModel()  # Set embedding size to 512
criterion = nn.MSELoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001)

# Move the model to the appropriate device (GPU or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Train the model with validation
train_model(model,train_dataloader ,test_dataloader , criterion, optimizer, num_epochs=10)