In [1]:
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter

from Affine_Transformations import generate_strain_tensors

In [2]:
# Load the dataset.
# torch.load is pickle-based: only load dataset files you trust.
# map_location="cpu" keeps every tensor on the CPU whatever device it was saved
# from (otherwise CUDA-saved tensors cannot be loaded on a CPU-only machine);
# batches are moved to `device` later, inside the evaluation/training loops.
file_path = "data/DeDoDe_Descriptors_Dataset.pth"  # Change this to your actual path
data = torch.load(file_path, map_location="cpu")
# The file is expected to be a dict with the keys 'descriptors' and
# 'deformed_descriptors' (and 'transformations', read in the selection cell);
# a missing key raises KeyError.
all_base_descriptors = data['descriptors']
all_transformed_descriptors = data['deformed_descriptors']

  data = torch.load(file_path)


In [9]:
# Select the data corresponding to a single transformation.
# `transformation` is an index into generate_strain_tensors().
transformation = 37

# data['transformations'] is expanded so that there is one transformation index
# per descriptor: each entry is repeated along dim 1, then everything is flattened.
# NOTE(review): 11 is assumed to be the number of descriptors stored per
# transformation entry — confirm against the script that built the dataset.
DESCRIPTORS_PER_TRANSFORMATION = 11
transformations = torch.repeat_interleave(
    data['transformations'], repeats=DESCRIPTORS_PER_TRANSFORMATION, dim=1
).flatten()
if transformations.shape[0] != all_base_descriptors.shape[0]:
    raise ValueError(
        f'{transformations.shape[0]} transformation labels for '
        f'{all_base_descriptors.shape[0]} descriptors; check DESCRIPTORS_PER_TRANSFORMATION'
    )

idx = transformations == transformation
base_descriptors = all_base_descriptors[idx]
transformed_descriptors = all_transformed_descriptors[idx]
# Without this check an unknown index silently yields empty tensors, NaN means and
# a training loop that iterates over nothing.
if len(base_descriptors) == 0:
    raise ValueError(f'no descriptors found for transformation {transformation}')

print(f'Transformation Tensor: {generate_strain_tensors()[transformation]}')
print(f'Number of Descriptions: {len(base_descriptors)}')
print(f'Mean Cosine Similarity: {torch.nn.functional.cosine_similarity(base_descriptors,transformed_descriptors,dim = 1).mean()}')

Transformation Tensor: (-0.25, 0.0, 0.0)
Number of Descriptions: 1023
Mean Cosine Similarity: 0.8963428496344185


In [20]:
# Create a dataset and a reproducible train/validation split with dataloaders.
# Dedicated seeded generators make the split and the shuffling order identical
# on every Restart & Run All (the global RNG is left untouched).
SEED = 0
BATCH_SIZE = 64

dataset = TensorDataset(base_descriptors, transformed_descriptors)
# dataset = TensorDataset(all_base_descriptors, all_transformed_descriptors)
train_size = 0.8
val_size = 0.2

train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(SEED)
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=False,
    generator=torch.Generator().manual_seed(SEED),
)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

In [21]:
# Define the baseline model: a single learned linear map.
class MLP(nn.Module):
    """Linear baseline mapping a descriptor of size `input_dim` to size `output_dim`.

    `hidden_dim` is currently unused; it is accepted so the constructor has the
    same signature as `Residual_MLP`. The layer lives inside `self.model`
    (an nn.Sequential), so state_dict keys are 'model.0.weight' / 'model.0.bias'.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=256):
        super().__init__()
        layers = [nn.Linear(input_dim, output_dim)]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the linear map to `x`."""
        return self.model(x)

In [22]:
# Define residual MLP model
class Residual_MLP(nn.Module):
    """Residual model y = x + f(x), with f = Linear -> ReLU -> Linear.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output; must equal
            `input_dim` because the input is added to the branch output.
        hidden_dim: width of the hidden layer of the residual branch.

    Raises:
        ValueError: if `input_dim != output_dim`. Without this check a mismatch
            fails only at forward time, or silently broadcasts when output_dim == 1.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=256):
        super(Residual_MLP, self).__init__()
        if input_dim != output_dim:
            raise ValueError(
                f"Residual_MLP requires input_dim == output_dim for the skip connection, "
                f"got input_dim={input_dim}, output_dim={output_dim}"
            )
        # Layer order (indices 0, 1, 2) is kept stable so saved state_dicts stay loadable.
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        """Return `x` plus the output of the residual branch."""
        residual = self.model(x)
        return x + residual

In [23]:
# TODO: define a residual conditional MLP (not implemented yet; this cell is intentionally empty)


In [24]:
# Initialize the model on the available device.
# The model takes the dtype of the descriptors: they are fed to it unconverted, so a
# hard-coded `.double()` would break (dtype mismatch) if the data were not float64.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
input_dim = base_descriptors.shape[1]
output_dim = transformed_descriptors.shape[1]
hidden_dim = 128
# model = MLP(input_dim, output_dim).to(device=device, dtype=base_descriptors.dtype)
model = Residual_MLP(input_dim, output_dim, hidden_dim=hidden_dim).to(
    device=device, dtype=base_descriptors.dtype
)

cuda


In [25]:
# Potentially apply identity initialisation to model
identity_initialisation = True

def init_identity(m):
    """Initialise an nn.Linear layer to an identity map; leave other modules untouched.

    The weight is filled by nn.init.eye_ (a rectangular identity for non-square
    layers) and the bias, when present, is zeroed. Intended for Module.apply.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.eye_(m.weight)
    if m.bias is not None:
        nn.init.zeros_(m.bias)

if identity_initialisation and isinstance(model,MLP):
    model.apply(init_identity)
    print('identity applied')


In [26]:
# Define loss function
def cosine_similarity(output, target, reduction='mean'):
    """Cosine-distance loss: 1 - cosine similarity between `output` and `target`.

    Despite its name this returns a *loss*: 0 for aligned vectors, up to 2 for
    opposite ones. Similarity is computed along dim 1 (the default of
    torch.nn.functional.cosine_similarity).

    Args:
        output: predicted tensor.
        target: tensor broadcastable to `output`.
        reduction: 'mean' (scalar), 'sum' (scalar) or 'none' (per-row losses).

    Returns:
        The reduced loss tensor.

    Raises:
        ValueError: if `reduction` is not one of the supported values.
    """
    loss = 1 - torch.nn.functional.cosine_similarity(output, target)

    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    if reduction == 'none':
        return loss
    raise ValueError(f"unsupported reduction {reduction!r}; expected 'mean', 'sum' or 'none'")

# Optimiser and learning-rate schedule.
# Alternative criteria tried earlier: nn.CosineEmbeddingLoss(reduction='none'),
# nn.MSELoss(reduction='none').
learning_rate = 1e-4
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Multiply the LR by 0.1 once the monitored (validation) loss has failed to
# improve for more than 5 consecutive scheduler steps.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

In [27]:
# One eval round before training, then train with per-epoch validation logged to TensorBoard.

def evaluate_losses(model, dataloader):
    """Return the per-sample cosine losses of `model` over `dataloader` as a 1-D NumPy array.

    Switches the model to eval mode and disables autograd; each batch is moved
    to the notebook-level `device` before the forward pass. An empty loader
    gives an empty array (np.mean of it is nan, with a RuntimeWarning).
    """
    model.eval()
    per_batch = []
    with torch.no_grad():
        for base, transformed in dataloader:
            base, transformed = base.to(device), transformed.to(device)
            output = model(base)
            per_batch.append(cosine_similarity(output, transformed, 'none').cpu().numpy())
    return np.concatenate(per_batch) if per_batch else np.empty(0)


start_loss = np.mean(evaluate_losses(model, val_dataloader))
print(f'Loss before training: {start_loss}')

writer = SummaryWriter(log_dir="runs/tain_one_transformation")

# Training loop
epochs = 250
try:
    for epoch in range(epochs):
        model.train()
        train_losses = []
        for base, transformed in train_dataloader:
            base, transformed = base.to(device), transformed.to(device)

            optimizer.zero_grad()
            output = model(base)
            losses = cosine_similarity(output, transformed, 'none')
            loss = losses.mean()
            loss.backward()
            optimizer.step()

            train_losses.extend(losses.detach().cpu().numpy())

        val_losses = evaluate_losses(model, val_dataloader)

        avg_train_loss = np.mean(train_losses)
        std_train_loss = np.std(train_losses)

        avg_val_loss = np.mean(val_losses)
        std_val_loss = np.std(val_losses)

        writer.add_scalar("Loss/Train", avg_train_loss, epoch)
        writer.add_scalar("Loss_std/Train", std_train_loss, epoch)
        writer.add_scalar("Loss/Validation", avg_val_loss, epoch)
        writer.add_scalar("Loss_std/Validation", std_val_loss, epoch)

        scheduler.step(avg_val_loss)

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {avg_val_loss:.6f}")
finally:
    # SummaryWriter buffers events; close it even when the loop is interrupted
    # (e.g. KeyboardInterrupt) so everything logged so far is written to disk.
    writer.close()

Loss before training: 0.13808156929684137
Epoch 10/250, Train Loss: 0.105026, Val Loss: 0.114934
Epoch 20/250, Train Loss: 0.097171, Val Loss: 0.109487
Epoch 30/250, Train Loss: 0.091474, Val Loss: 0.105748
Epoch 40/250, Train Loss: 0.086177, Val Loss: 0.102478
Epoch 50/250, Train Loss: 0.081191, Val Loss: 0.099558
Epoch 60/250, Train Loss: 0.076648, Val Loss: 0.096947
Epoch 70/250, Train Loss: 0.072645, Val Loss: 0.094752
Epoch 80/250, Train Loss: 0.069174, Val Loss: 0.092904
Epoch 90/250, Train Loss: 0.066151, Val Loss: 0.091374
Epoch 100/250, Train Loss: 0.063492, Val Loss: 0.090114
Epoch 110/250, Train Loss: 0.061149, Val Loss: 0.089025
Epoch 120/250, Train Loss: 0.059060, Val Loss: 0.088129
Epoch 130/250, Train Loss: 0.057179, Val Loss: 0.087374
Epoch 140/250, Train Loss: 0.055484, Val Loss: 0.086730
Epoch 150/250, Train Loss: 0.053933, Val Loss: 0.086208
Epoch 160/250, Train Loss: 0.052521, Val Loss: 0.085734
Epoch 170/250, Train Loss: 0.051232, Val Loss: 0.085314
Epoch 180/250, 

In [28]:
# Save the trained model's weights (state_dict only; rebuild the model class to reload them).
model_path = Path("models") / "single_transformation_model.pth"
# torch.save does not create missing parent directories.
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), model_path)
print("Model training complete and saved.")

Model training complete and saved.
