In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader
import pickle

import matplotlib.pyplot as plt
import os


In [3]:
# Directory holding the pre-split .npy files; the three split paths below differ only in file name.
DATA_SPLITS_DIR = '/home/CAMPUS/hdasari/apebench_experiments/final_2d_experiments/vanilla_unet_2d/data_splits'

train_data_path = f"{DATA_SPLITS_DIR}/train_100.npy"
val_data_path = f"{DATA_SPLITS_DIR}/val.npy"
test_data_path = f"{DATA_SPLITS_DIR}/test.npy"

In [4]:
# from src_codes.model_perform.training import training_loop
from src_codes.models.primary_func import PrimaryNetwork
# from src_codes.DataLoaders.KsDataset import KSDataset

In [9]:
# Checkpoint path handed to PrimaryNetwork via its `unet_1d_weights_path` argument
# (presumably loaded inside the network's constructor -- confirm in src_codes.models.primary_func).
unet_1d_weights_path = '/home/CAMPUS/hdasari/apebench_experiments/mse_experiments/vanilla_1d/checkpoints/new_june18_2_mse_epoch_20_unet_1d_weights_biases.pth'

# Number of passes over the training loader.
epochs = 30

# Use the GPU when torch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the network, then move it to the selected device.
model = PrimaryNetwork(unet_1d_weights_path=unet_1d_weights_path, device=device)
model = model.to(device)

# Mean-reduced MSE objective, optimised with Adam plus weight decay.
criterion = nn.MSELoss(reduction='mean')
optimizer = optim.Adam(params=model.parameters(), lr=1e-3, weight_decay=1e-4)

  state_dict_1d = torch.load(path, map_location=self.device)


In [10]:
import torch
from torch.utils.data import Dataset
import numpy as np

class KSBatchDataset(Dataset):
    """Pairs two indexable containers so that item ``i`` is ``(inputs[i], targets[i])``.

    The dataset length is taken from the first axis of ``inputs``; ``targets``
    is indexed with the same index and is not checked for a matching length.
    """

    def __init__(self, inputs, targets):
        """Keep references to ``inputs`` and ``targets``; nothing is copied."""
        self.inputs, self.targets = inputs, targets

    def __len__(self):
        """Return the size of the leading axis of ``inputs``."""
        leading_dim = self.inputs.shape[0]
        return leading_dim

    def __getitem__(self, idx):
        """Return the ``(input, target)`` tuple stored at position ``idx``."""
        sample_input = self.inputs[idx]
        sample_target = self.targets[idx]
        return sample_input, sample_target



In [11]:
def _load_split(npy_path):
    """Load one .npy split.

    Returns a 3-tuple: the raw numpy array, a float32 tensor copy of it, and a
    view of that tensor reshaped to (-1, 43, 1, 160, 160).
    """
    raw = np.load(npy_path)
    as_tensor = torch.tensor(raw, dtype=torch.float32)
    # Axes are presumably (trajectory, time step, channel, height, width) -- confirm against the data files.
    return raw, as_tensor, as_tensor.view(-1, 43, 1, 160, 160)


ks_train_data, train_tensor, ks_train = _load_split(train_data_path)
ks_val_data, val_tensor, ks_val = _load_split(val_data_path)
ks_test_data, test_tensor, ks_test = _load_split(test_data_path)

# Inputs drop the last index along axis 1 and targets drop the first,
# so targets[i, t] is the entry that follows inputs[i, t] on that axis.
ks_x_train_data, ks_y_train_data = ks_train[:, :-1], ks_train[:, 1:]

print("ks_x_train_data shape:", ks_x_train_data.shape)
print("ks_y_train_data shape:", ks_y_train_data.shape)

ks_x_val_data, ks_y_val_data = ks_val[:, :-1], ks_val[:, 1:]

train_dataset = KSBatchDataset(ks_x_train_data, ks_y_train_data)
val_dataset = KSBatchDataset(ks_x_val_data, ks_y_val_data)

# One dataset item per batch; only the training loader is shuffled.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=1)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=1)

ks_x_train_data shape: torch.Size([1500, 42, 1, 160, 160])
ks_y_train_data shape: torch.Size([1500, 42, 1, 160, 160])


In [12]:
import torch
from tqdm import tqdm
import os


def _save_checkpoint(path, epoch, model, optimizer, train_loss, val_loss):
    """Write model/optimizer state and the given epoch's losses to ``path`` with ``torch.save``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss
    }, path)


def _run_epoch(model, device, criterion, loader, desc, optimizer=None):
    """Run one pass over ``loader`` and return the sample-weighted mean loss.

    A gradient step is taken per batch when ``optimizer`` is given; otherwise
    only the forward pass and the loss are computed. Returns NaN when the
    loader yields no samples.
    """
    weighted_loss = 0.0
    samples_seen = 0

    for inputs, targets in tqdm(loader, desc=desc):
        # Drops a leading batch axis of size 1 (no-op for larger batches).
        inputs, targets = inputs.squeeze(0), targets.squeeze(0)
        inputs, targets = inputs.to(device), targets.to(device)

        output = model(inputs)
        loss = criterion(output, targets)

        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Weight by the leading axis *after* the squeeze and remember how many
        # samples were counted, so the final division uses the same unit.
        batch_samples = inputs.size(0)
        weighted_loss += loss.item() * batch_samples
        samples_seen += batch_samples

    if samples_seen == 0:
        return float('nan')
    return weighted_loss / samples_seen


def training_loop(model, device, criterion, optimizer, scheduler, train_loader, val_loader, epochs, len_train_dataset, len_val_dataset, storing_path, res_path):
    """Train ``model`` for ``epochs`` epochs, validating, logging and checkpointing each epoch.

    Args:
        model: module to train; switched between ``train()`` and ``eval()`` per phase.
        device: device that every batch is moved to.
        criterion: callable ``criterion(output, targets)`` returning a scalar loss.
        optimizer: optimizer stepped once per training batch.
        scheduler: optional LR scheduler, stepped once per epoch; may be ``None``.
        train_loader: iterable of ``(inputs, targets)`` training batches.
        val_loader: iterable of ``(inputs, targets)`` validation batches.
        epochs: number of epochs to run.
        len_train_dataset: kept for backward compatibility; no longer used.
        len_val_dataset: kept for backward compatibility; no longer used.
        storing_path: directory for checkpoints (created if missing).
        res_path: directory for ``training_log.txt`` (created if missing).

    Returns:
        ``(train_losses, val_losses)``: two lists with one mean loss per epoch.

    Notes:
        Each batch loss is weighted by the batch's leading-axis size and the sum
        is divided by the total of those sizes. The previous version divided by
        the dataset length instead, which overstated the loss whenever
        ``squeeze(0)`` removed a batch axis of size 1 and the leading axis
        became a different dimension.
    """
    # Fail early (or not at all) instead of after the first full epoch.
    os.makedirs(storing_path, exist_ok=True)
    os.makedirs(res_path, exist_ok=True)
    log_path = os.path.join(res_path, 'training_log.txt')

    train_losses = []
    val_losses = []
    best_val_loss = float('inf')

    for epoch in range(epochs):
        model.train()
        avg_train_loss = _run_epoch(
            model, device, criterion, train_loader,
            f"[Train Epoch {epoch+1}/{epochs}]", optimizer,
        )
        train_losses.append(avg_train_loss)

        model.eval()
        with torch.no_grad():
            avg_val_loss = _run_epoch(
                model, device, criterion, val_loader,
                f"[Val Epoch {epoch+1}/{epochs}]",
            )
        val_losses.append(avg_val_loss)

        log_line = f"Epoch {epoch+1}/{epochs} | Train MSE: {avg_train_loss:.6f} | Val MSE: {avg_val_loss:.6f}"
        print(log_line)
        with open(log_path, "a") as f:
            f.write(log_line + "\n")

        # Periodic snapshot every fifth epoch.
        if (epoch + 1) % 5 == 0:
            save_path = os.path.join(storing_path, f"model_epoch_{epoch+1}.pth")
            _save_checkpoint(save_path, epoch + 1, model, optimizer, avg_train_loss, avg_val_loss)
            print(f"Saved checkpoint at epoch {epoch+1} to {save_path}")

        print("\n")

        # Overwrite the single "best" checkpoint whenever validation improves.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            _save_checkpoint(
                os.path.join(storing_path, 'best_model_checkpoint.pth'),
                epoch + 1, model, optimizer, avg_train_loss, avg_val_loss,
            )
            print(f" Saved BEST checkpoint at epoch {epoch+1}")
            with open(log_path, "a") as f:
                f.write(f"Saved BEST checkpoint at epoch {epoch+1}\n")

        if scheduler is not None:
            scheduler.step()

    return train_losses, val_losses


In [13]:
# Output locations: checkpoints go to data_storing_path, logs and loss curves to results_storing_path.
data_storing_path = '/home/CAMPUS/hdasari/apebench_experiments/mse_experiments/extrusion_2d/checkpoints/new_check_june18'

results_storing_path = '/home/CAMPUS/hdasari/apebench_experiments/mse_experiments/extrusion_2d/results/new_results_june18'

# Create both directories up front so a missing folder cannot abort the run
# after an epoch of training or when the results are written at the end.
os.makedirs(data_storing_path, exist_ok=True)
os.makedirs(results_storing_path, exist_ok=True)

# No LR scheduler is used for this run (scheduler argument is None).
train_losses, val_losses = training_loop(model, device, criterion, optimizer, None, train_loader, val_loader, epochs, len(train_dataset), len(val_dataset), data_storing_path, results_storing_path)

# Despite the file name, the pickle holds the tuple (train_losses, val_losses).
with open(os.path.join(results_storing_path, 'train_losses.pkl'), 'wb') as f:
    pickle.dump((train_losses, val_losses), f)

with open(os.path.join(results_storing_path, 'training_log.txt'), "a") as f:
    f.write("Training completed\n")


[Train Epoch 1/30]:   4%|▍         | 58/1500 [00:22<09:22,  2.56it/s] 


KeyboardInterrupt: 