# Summary

Notebook for distributed training of a Koopman operator model.

# Imports/Setup

In [2]:
from accelerate import Accelerator, notebook_launcher
import torch
import numpy as np
import matplotlib.pyplot as plt
import data
import dataset
import model
import evaluation
import training
import distributed
from torch import nn
from diffusers import UNet2DModel
from transformers import ViTModel, ViTConfig
import torch.optim as optim
from safetensors.torch import load_file

In [3]:
class Config:
    """Static settings read by the training cells of this notebook."""

    # --- dataset ---
    path = '/data/users/jupyter-dam724/colliding_solutions'
    solver = 'ros2'
    fixed_seq_len = 216
    batch_size = 16  # starting value; scaled by `processes` in the distribution section
    ahead = tail = 1
    aug = False

    # --- device ---
    device_pref = 'cuda'
    device_ind = None

    # --- model / optimisation ---
    epoches, patience = 30, 10
    lr = 1e-5
    save_path = '/data/users/jupyter-dam724/koopman-vit/checkpoint/'
    from_checkpoint = None  # when not None, acelerate_ddp loads a state dict from this path
    p = False  # forwarded unchanged to model.ViTOperatorFlex
    latent_size = 2048
    heads = latent_size // 64  # one attention head per 64 latent dimensions

    # --- distribution ---
    processes = 2
    batch_size *= processes  # final value handed to dataset.main and the model
    tworkers = vworkers = 32
    grad_accumulate = 4

# Training

In [4]:
def step(batch, model, criterion):
    """Return the loss for a single batch.

    Args:
        batch: A two-element iterable; both elements are passed to ``model``.
        model: Callable invoked as ``model(first, second)``; its return value
            is unpacked positionally into ``criterion``.
        criterion: Callable that turns the model outputs into a loss.

    Returns:
        Whatever ``criterion`` returns for the model outputs.
    """
    inputs, targets = batch
    outputs = model(inputs, targets)
    return criterion(*outputs)

In [4]:
def acelerate_ddp():
    """Set up data, model and optimiser, then run training under Accelerate.

    Takes no arguments and returns nothing; every setting is read from
    ``Config``. The function is the entry point passed to
    ``notebook_launcher`` elsewhere in this notebook.
    """
    accelerator = Accelerator(gradient_accumulation_steps=Config.grad_accumulate)

    # Load the train/validation splits; the first item returned by
    # data.main is not used here.
    _, (train_x, train_y), (valid_x, valid_y) = data.main(
        path=Config.path,
        device_pref=Config.device_pref,
        solver=Config.solver,
        fixed_seq_len=Config.fixed_seq_len,
        ahead=Config.ahead,
        tail=Config.tail,
        device_ind=Config.device_ind,
    )

    # Wrap the splits in the train/validation loaders.
    train_loader, valid_loader = dataset.main(
        x_train_data=train_x,
        y_train_data=train_y,
        x_valid_data=valid_x,
        y_valid_data=valid_y,
        batch_size=Config.batch_size,
        tworkers=Config.tworkers,
        vworkers=Config.vworkers,
        aug=Config.aug,
    )

    # Freshly initialised ViT backbone (ViTModel built from a config, not
    # from pretrained weights), wrapped by the operator model.
    backbone = ViTModel(
        ViTConfig(
            hidden_size=Config.latent_size,
            num_attention_heads=Config.heads,
            intermediate_size=4096,
            num_hidden_layers=12,
            num_channels=3,
        )
    )
    model.unfreeze(backbone)
    operator = model.ViTOperatorFlex(
        backbone,
        batch_size=Config.batch_size,
        p=Config.p,
        latent_size=Config.latent_size,
    )

    # Optionally resume from a safetensors checkpoint.
    if Config.from_checkpoint is not None:
        operator.load_state_dict(load_file(Config.from_checkpoint))

    optimizer = optim.AdamW(operator.parameters(), lr=Config.lr)
    # Plateau scheduler: scales the learning rate by 0.1 (patience 3,
    # threshold 1e-3) when the monitored value stops decreasing.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.1, patience=3, threshold=1e-3
    )

    # Everything that takes part in training goes through `accelerator.prepare`.
    prepared = accelerator.prepare(
        train_loader, valid_loader, operator, optimizer, scheduler
    )
    train_loader, valid_loader, operator, optimizer, scheduler = prepared

    # Hand control to the project's training loop; the two empty lists
    # are passed in as its train/validation logs.
    training.accelerator_train(
        accelerator=accelerator,
        train=train_loader,
        valid=valid_loader,
        model=operator,
        epochs=Config.epoches,
        patience=Config.patience,
        criterion=model.OperatorLoss(1.0),
        save_path=Config.save_path,
        step=step,
        train_log=[],
        valid_log=[],
        optimizer=optimizer,
        scheduler=scheduler,
        loading_bar=False,
    )

In [None]:
# Launch acelerate_ddp (no positional arguments) across Config.processes processes from inside the notebook.
notebook_launcher(acelerate_ddp, args=(), num_processes=Config.processes)

Launching training on 2 GPUs.
Now using GPU.
Now using GPU.
Train size: 139607, Percent of toal: 74.66%, Unique instances: 700
Train size: 47394, Percent of toal: 25.34%, Unique instances: 240
Train size: 139607, Percent of toal: 74.66%, Unique instances: 700
Train size: 47394, Percent of toal: 25.34%, Unique instances: 240
Epoch 1/30, Train Loss: 45.123804167616754, Validation Loss: 37.27319822569151
Epoch 2/30, Train Loss: 30.106785648735986, Validation Loss: 24.34978584083351
Epoch 3/30, Train Loss: 21.60133946002049, Validation Loss: 19.219888980968577
Epoch 4/30, Train Loss: 17.63250692455881, Validation Loss: 16.66112452970969
Epoch 5/30, Train Loss: 15.904254290665142, Validation Loss: 15.344738470541465
Epoch 6/30, Train Loss: 15.14991004092056, Validation Loss: 14.78170838484893
Epoch 7/30, Train Loss: 14.749128958218035, Validation Loss: 14.574717275516408
Epoch 8/30, Train Loss: 14.40631745289029, Validation Loss: 14.376058572047466
Epoch 9/30, Train Loss: 14.189087448837231

In [None]:
# TODO: make a notebook for the UNet next.
# 