In [1]:
import os

import sys
sys.path.append('..')

import toml

import matplotlib.pyplot as plt

import torch
import numpy as np

from models.vision_transformer import ViT
from data.datasets import TimeSeriesDataset
from torch.utils.data import DataLoader

torch.autograd.set_detect_anomaly(True)

%load_ext autoreload
%autoreload 2

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# --- Configuration -----------------------------------------------------------
# Every setting below (checkpoint location, model kwargs, dataset kwargs,
# fine-tuning hyper-parameters) is read from this TOML file.
config_path = '/glade/u/home/jshen/pruning-turbulence-vit/src/config/prune/l1_rollout.toml'
config = toml.load(config_path)
# config['checkpoint_file'] = '/glade/derecho/scratch/jshen/base/best.tar'

# --- Checkpoint --------------------------------------------------------------
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects;
# only point this at checkpoints you trust.
state_dict = torch.load(
    config['checkpoint_file'],
    map_location=device,
    weights_only=False,
)
# Detach the optimizer state (if the checkpoint has one) from the loaded dict.
optimizer_state = state_dict.pop('optimizer_state', None)
# Use the nested 'model_state' entry when present; otherwise treat whatever is
# left of the loaded dict as the model weights themselves.
if 'model_state' in state_dict:
    model_state_dict = state_dict.pop('model_state')
else:
    model_state_dict = state_dict

# --- Model -------------------------------------------------------------------
model = ViT(**config['model'])
model = model.to(device)
model.load_state_dict(model_state_dict)

# --- Dataset -----------------------------------------------------------------
# Scale the configured target step by the number of rollout steps before the
# dataset is built from the same config section.
rollout_steps = config['finetuning']['num_rollout_steps']
config['train_dataset']['target_step'] *= rollout_steps
dataset = TimeSeriesDataset(**config['train_dataset'])

In [None]:
# Single-process DDP emulation
#
# Brings up a torch.distributed process group inside the notebook kernel and
# wraps `model` in DistributedDataParallel. The cell is safe to re-run: it does
# not re-initialise a live process group and it never stacks DDP wrappers.
import os
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Rendezvous defaults for a one-rank job on localhost. setdefault() keeps any
# value that is already exported in the environment.
for _env_key, _env_default in {
    "MASTER_ADDR": "127.0.0.1",
    "MASTER_PORT": "29500",
    "RANK": "0",
    "WORLD_SIZE": "1",
    "LOCAL_RANK": "0",
}.items():
    os.environ.setdefault(_env_key, _env_default)

use_cuda = torch.cuda.is_available()
backend = "nccl" if use_cuda else "gloo"

# Initialising the default process group twice raises, so only do it when no
# group is currently alive (e.g. first run, or after it was destroyed).
if not dist.is_initialized():
    dist.init_process_group(
        backend=backend,
        init_method="env://",
        rank=int(os.environ["RANK"]),
        world_size=int(os.environ["WORLD_SIZE"]),
    )

local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{local_rank}" if use_cuda else "cpu")
if use_cuda:
    torch.cuda.set_device(device)

# Adjust batch size to mirror script behavior
# NOTE(review): this divides the value stored in `config` in place, so with
# WORLD_SIZE > 1 each re-run of this cell shrinks the batch size again.
config['finetuning']['batch_size'] = (
    int(config['finetuning']['batch_size']) // int(os.environ["WORLD_SIZE"])
)

# Wrap model with DDP
# If the cell was run before, `model` is already a DDP wrapper: recover the
# underlying module so the fresh wrapper binds to the current process group
# instead of nesting DDP inside DDP.
_base_model = model.module if isinstance(model, DDP) else model
# Place the module on this rank's device before wrapping; DDP rejects a module
# whose parameters live on a different device than `device_ids`.
_base_model = _base_model.to(device)
model = DDP(
    _base_model,
    device_ids=[local_rank] if use_cuda else None,
    output_device=local_rank if use_cuda else None,
)

In [None]:
from torch.utils.data.distributed import DistributedSampler

# All hyper-parameters used in this cell come from the fine-tuning section.
finetune_cfg = config['finetuning']

# The sampler owns shuffling, so the DataLoader receives it through `sampler=`
# and is given no `shuffle` flag of its own.
sampler = DistributedSampler(dataset, shuffle=True)
dataloader = DataLoader(
    dataset,
    batch_size=finetune_cfg['batch_size'],
    sampler=sampler,
    num_workers=4,
)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=finetune_cfg['lr'],
    weight_decay=finetune_cfg['weight_decay'],
)
# Restore the optimizer state only when one was extracted from the checkpoint.
if optimizer_state is not None:
    optimizer.load_state_dict(optimizer_state)

In [None]:
# Short debug run: take a handful of optimisation steps with the rollout loss,
# print each loss, then tear the process group down.
import torch.distributed as dist

# Number of optimisation steps to run before stopping.
NUM_DEBUG_STEPS = 4

sampler.set_epoch(0)

model.train()
try:
    # zip() asks range() first and stops as soon as it is exhausted, so the
    # loop ends after NUM_DEBUG_STEPS batches without pulling an extra batch
    # from the dataloader just to discard it.
    for step, (ic, target) in zip(range(NUM_DEBUG_STEPS), dataloader):
        ic, target = ic.to(device), target.to(device)

        optimizer.zero_grad(set_to_none=True)

        # Mirror train_one_epoch rollout behavior
        for _ in range(config['finetuning']['num_rollout_steps']):
            y_pred = model(ic)
            # Drop the last entry along dim 2 and put the new prediction in
            # front along that same axis, forming the next model input.
            # (dim 2 is presumably the time axis; confirm against the dataset.)
            prev_ic = ic[:, :, :-1, :, :].contiguous()
            ic = torch.cat([y_pred, prev_ic], dim=2)

        # Only the prediction from the final rollout step is scored.
        loss = torch.nn.functional.mse_loss(y_pred, target)
        loss.backward()
        print(loss.item())
        optimizer.step()
finally:
    # Clean up process group
    # Runs even when a step above raises, so a failed run does not leave a
    # live process group behind. A process group must be initialised again
    # before this cell can be re-run.
    if dist.is_initialized():
        dist.destroy_process_group()

0.00034157533082179725
0.00036821269895881414
0.0003726504510268569
0.0003312506014481187
