In [1]:
'''
NEXT:
* Save checkpoints to /sdf/data: yes
* Split data into train/val/test
* Save 1 image per val epoch
* Test method: visualize full patches + reconstructions, all tokens + reconstructions
'''

'\nNEXT:\n* Save checkpoints to /sdf/data: yes\n* Split data into train/val/test\n* Save 1 image per val epoch\n* Test method: visualize full patches + reconstructions, all tokens + reconstructions\n'

In [2]:
import torch
import numpy as np
from preparation import *
from data_utils import *
from transformer_things import *
from torch.utils.data import DataLoader, random_split
from torch.utils.data import Subset
import importlib
import numpy as np
from utils import *
import wandb
import torch.optim as optim
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR

from pos_embed_model import *
import pickle

# paths
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


### Loading in Data

In [3]:
# Dummy Data
# data = np.random.rand(5, 250, 250) # assume each value in the voxel is a fluorescence intensity
# data_list = [data] * 50
load_patches = np.load('/sdf/data/neutrino/carsmith/all_global_norm_patches.npy') # array of shape (2519, 5, 250, 250)
data_list = load_patches.tolist()

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7f338c625c00>>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 


In [4]:
print(load_patches.shape)
dataset = CubeDataset(load_patches)

val_ratio = 0.1
test_ratio = 0.05
total_size = len(dataset)
val_size = int(total_size * val_ratio)
test_size = int(total_size * test_ratio)
train_size = total_size - val_size - test_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size], generator=torch.Generator().manual_seed(42))

batch_size = 64

# def custom_collate_fn(batch, mask_percentage=0.6, kernel=12):
#     cubes = torch.stack([torch.tensor(cube, dtype=torch.float32) for cube in batch])  # (B, Z, Y, X)
#     return cubes # all are (B, Z, Y, X)
def custom_collate_fn(batch, mask_percentage=0.6, kernel=12):
    cubes = torch.stack([torch.tensor(cube, dtype=torch.float32) for cube in batch])  # (B, Z, Y, X)
    cubes = cubes.unsqueeze(1)  # (B, 1, Z, Y, X)
    return cubes

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=custom_collate_fn,
    num_workers=0,
    pin_memory=False,
    drop_last=True
)

# used to test model capacity by overfitting on a small training set
overfit_loader = DataLoader(
    Subset(train_dataset, list(range(256))),  # first 10 samples
    batch_size=batch_size,
    shuffle=True,
    collate_fn=custom_collate_fn,
    num_workers=0,
    pin_memory=False,
    drop_last=False # we don't give it a full batch
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,  # don't shuffle for validation
    collate_fn=custom_collate_fn,
    num_workers=0,
    pin_memory=False,
    drop_last=False
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,  # don't shuffle for test
    collate_fn=custom_collate_fn,
    num_workers=0,
    pin_memory=False,
    drop_last=False
)

(2519, 5, 250, 250)


In [None]:
# Testing data loading
cubes = next(iter(train_loader))
cubes = cubes.to(device)
print(len(cubes))
print(f"Cubes shape w/ feature dimension: {cubes.shape}")

test_cubes = next(iter(overfit_loader))
print(len(test_cubes))

for batch in val_loader:
    print(type(batch))
    print(batch.shape)
    break

  cubes = torch.stack([torch.tensor(cube, dtype=torch.float32) for cube in batch])  # (B, Z, Y, X)


### Instantiating le Model

In [None]:
patch_size = (5, 10, 10)
embed_dim = 384
px, py, pz = patch_size
output_dim = px * py * pz
masking_ratio = 0.2 # simple autoencoder test

# all the guys
patch_model = PatchEmbed3D(patch_size=patch_size, embed_dim=embed_dim).to(device)
pos_model = LearnedPositionalEncoder(in_dim=3, embed_dim=embed_dim).to(device)
transformer_encoder = VisionTransformer3D(input_dim=embed_dim).to(device)
transformer_decoder = Decoder(embed_dim=embed_dim, hidden_dim=1024, num_layers=2, output_dim=output_dim, num_heads=4).to(device)

# training params
epochs = 400
model = Model(patch_model, pos_model, transformer_encoder, transformer_decoder, patch_size, embed_dim, output_dim, device, masking_ratio, patch_size)
model = model.to(device)

In [None]:
lr=1e-4
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-6)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) 

# Warmup for first 5 epochs
warmup_epochs = 5
total_epochs = epochs
decay_epochs = total_epochs - warmup_epochs

# warmup schedule
warmup_scheduler = LinearLR(optimizer, start_factor=1e-2, end_factor=1.0, total_iters=warmup_epochs)

# cosine annealing after warmup
cosine_scheduler = CosineAnnealingLR(optimizer, T_max=decay_epochs, eta_min=1e-5)  # <- eta_min prevents LR from collapsing

scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[warmup_epochs])

# per epoch stats
results = {}
results['train_loss'] = []
# results['val_loss'] = []
results['val_loss'] = []

wandb.init(
    project="mae_pretraining",
    name="stratified_training",
    config={
        "epochs": epochs,
        "batch_size": batch_size,
        "lr": lr
    }
)

for epoch in range(1, epochs+1):
    train_loss = pretrain(model, train_loader, optimizer, epoch, epochs, device, patch_size, batch_size)
    results['train_loss'].append(train_loss)
    val_loss = validate(model, val_loader, epoch, device, patch_size, batch_size=batch_size, save_dir='stratified_val_plots')
    results['val_loss'].append(val_loss)
    # overfit_loss = overfit_test(model, overfit_loader, epoch, device, patch_size, batch_size, save_dir='smalltoken_mask_plots', plot_num = 1)
    # results['overfit_loss'].append(overfit_loss)

    wandb.log({
        "pretrain_loss": train_loss,
        # "overfit_nopos_loss": overfit_loss,
        "val_loss": val_loss,
        "learning_rate": scheduler.get_last_lr()[0],
        "epoch": epoch,
    })

    if epoch % 10 == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
        }, f'/sdf/data/neutrino/carsmith/strat_mae_ckpts/checkpoint_{epoch}.pth')

    scheduler.step()

np.save('stratified_train_results.npy', results, allow_pickle=True)
test_loss = final_test(model, test_loader, device, patch_size, batch_size, save_dir = 'stratified_trained')
wandb.finish()

In [None]:
# Testing
#def final_test(model, test_loader, device='cuda', patch_size=(5, 10, 10), batch_size=64, save_dir='test_imgs':
np.save('long_train_results.npy', results, allow_pickle=True)
test_loss = final_test(model, test_loader, device, patch_size, batch_size)

In [None]:
print(sum(p.numel() for p in model.parameters() if p.requires_grad))