In [190]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from Affine_Transformations import generate_strain_tensors

In [191]:
# Load the dataset
file_path = "data/DeDoDe_Descriptors_Dataset.pth"  # Change this to your actual path
data = torch.load(file_path)
all_base_descriptors = data['descriptors']  # Assuming these keys exist
all_transformed_descriptors = data['deformed_descriptors']

  data = torch.load(file_path)


In [192]:
# Select the data corresponding to transformation x
transformation = 0
transformations = data['transformations']
transformations = torch.repeat_interleave(transformations, repeats=8, dim=1).flatten()
idx = transformations == transformation
base_descriptors = all_base_descriptors[idx]
transformed_descriptors = all_transformed_descriptors[idx]
print(f'Transformation Tensor: {generate_strain_tensors()[transformation]}')
print(f'Number of Descriptions: {len(base_descriptors)}')
print(f'Mean Cosine Similarity: {torch.nn.functional.cosine_similarity(base_descriptors,transformed_descriptors,dim = 1).mean()}')

Transformation Tensor: (-0.5, -0.5, -0.4)
Number of Descriptions: 1144
Mean Cosine Similarity: 0.09418456388903054


In [193]:
# Create a dataset and dataloader
dataset = TensorDataset(base_descriptors, transformed_descriptors)
# dataset = TensorDataset(all_base_descriptors, all_transformed_descriptors)
train_size = 0.8
val_size = 0.2

train_dataset, val_dataset = random_split(dataset,[train_size,val_size])

train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last = True)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False, drop_last = True)

In [194]:
# Define the MLP model
class MLP(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=64):
        super(MLP, self).__init__()
        self.model = nn.Sequential(
            # nn.Linear(input_dim,output_dim)
            
            nn.Linear(input_dim, hidden_dim),
            # nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
            
            # nn.Linear(input_dim, hidden_dim),
            # nn.ReLU(),
            # nn.Linear(hidden_dim, hidden_dim),
            # nn.ReLU(),
            # nn.Linear(hidden_dim, output_dim)
        )
    
    def forward(self, x):
        return self.model(x)

In [215]:
# Define residual MLP model
class Residual_MLP(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=256):
        super(Residual_MLP, self).__init__()
        self.model = nn.Sequential(
            # nn.Linear(input_dim,output_dim)
            
            nn.Linear(input_dim, hidden_dim),
            # nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
            
            # nn.Linear(input_dim, hidden_dim),
            # nn.ReLU(),
            # nn.Linear(hidden_dim, hidden_dim),
            # nn.ReLU(),
            # nn.Linear(hidden_dim, output_dim)
        )
    
    def forward(self, x):
        residual = self.model(x)
        return x + residual

In [216]:
# Initialize the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
input_dim = base_descriptors.shape[1]
output_dim = transformed_descriptors.shape[1]
# model = MLP(input_dim,output_dim).double().to(device)
model = Residual_MLP(input_dim,output_dim).double().to(device)

cuda


In [217]:
# Potentially apply identity initialisation to model
identity_initialisation = True

def init_identity(m):
    if isinstance(m, nn.Linear):
        nn.init.eye_(m.weight)  # Initialize as identity
        if m.bias is not None:
            nn.init.zeros_(m.bias)

if identity_initialisation:
    model.apply(init_identity)


In [218]:
# Define loss function and optimizer
def cosine_similarity(output,target,reduction = 'mean'):
    loss = 1 - torch.nn.functional.cosine_similarity(output,target)

    if reduction == 'mean':
        return loss.mean()
    elif reduction == 'none':
        return loss

# criterion = nn.CosineEmbeddingLoss(reduction='none')
# criterion = nn.MSELoss(reduction='none')
optimizer = optim.Adam(model.parameters(), lr=0.0001)

In [219]:
# One eval round before training
model.eval()
start_losses = []
with torch.no_grad():
    for base, transformed in val_dataloader:
        base, transformed = base.to(device), transformed.to(device)
        output = model(base)
        # losses = criterion(output,transformed)
        losses = cosine_similarity(output,transformed,'none')
        loss = losses.mean()
        start_losses.extend(losses.cpu().numpy())

start_loss = np.mean(start_losses)
print(f'Loss before training: {start_loss}')

writer = SummaryWriter(log_dir="runs/tain_one_transformation")

# Training loop
epochs = 250
for epoch in range(epochs):
    model.train()
    train_losses = []
    for base, transformed in train_dataloader:
        base, transformed = base.to(device), transformed.to(device)
        
        optimizer.zero_grad()
        output = model(base)
        # losses = criterion(output, transformed)
        losses = cosine_similarity(output,transformed,'none')
        loss = losses.mean()
        loss.backward()
        optimizer.step()
        
        train_losses.extend(losses.cpu().detach().numpy())

    model.eval()
    val_losses = []
    with torch.no_grad():
        for base, transformed in val_dataloader:
            base, transformed = base.to(device), transformed.to(device)
            output = model(base)
            # losses = criterion(output,transformed)
            losses = cosine_similarity(output,transformed,'none')
            val_losses.extend(losses.cpu().numpy())

    avg_train_loss = np.mean(train_losses)
    std_train_loss = np.std(train_losses)
    
    avg_val_loss = np.mean(val_losses)
    std_val_loss = np.std(val_losses)

    writer.add_scalar("Loss/Train", avg_train_loss, epoch)
    writer.add_scalar("Loss_std/Train", std_train_loss, epoch)
    writer.add_scalar("Loss/Validation", avg_val_loss, epoch)
    writer.add_scalar("Loss_std/Validation", std_val_loss, epoch)

    if (epoch+1)%10 == 0:
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {avg_val_loss:.6f}")

Loss before training: 0.9087612115807794
Epoch 10/250, Train Loss: 0.734561, Val Loss: 0.736868
Epoch 20/250, Train Loss: 0.657324, Val Loss: 0.671197
Epoch 30/250, Train Loss: 0.619843, Val Loss: 0.639969
Epoch 40/250, Train Loss: 0.592612, Val Loss: 0.619579
Epoch 50/250, Train Loss: 0.569094, Val Loss: 0.603314
Epoch 60/250, Train Loss: 0.546415, Val Loss: 0.589297
Epoch 70/250, Train Loss: 0.526848, Val Loss: 0.576743
Epoch 80/250, Train Loss: 0.504287, Val Loss: 0.565736
Epoch 90/250, Train Loss: 0.485251, Val Loss: 0.556114
Epoch 100/250, Train Loss: 0.467478, Val Loss: 0.547761
Epoch 110/250, Train Loss: 0.450103, Val Loss: 0.540493
Epoch 120/250, Train Loss: 0.434217, Val Loss: 0.534030
Epoch 130/250, Train Loss: 0.419972, Val Loss: 0.528149
Epoch 140/250, Train Loss: 0.406522, Val Loss: 0.523231
Epoch 150/250, Train Loss: 0.393360, Val Loss: 0.518749
Epoch 160/250, Train Loss: 0.381571, Val Loss: 0.514501
Epoch 170/250, Train Loss: 0.369454, Val Loss: 0.510785
Epoch 180/250, T

In [None]:
# Save the trained model
torch.save(model.state_dict(), "models/single_transformation_model.pth")
print("Model training complete and saved.")