In [None]:
import sys

# Presumably the project-local modules (data_utils, data.PILArNet, model, loss) live one directory up;
# make that directory importable.
sys.path.extend([".."])

In [None]:
import MinkowskiEngine as ME
import torch
from torchvision import transforms
import random

import plotly.express as px
import numpy as np
import plotly.graph_objects as go
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
import os

os.getcwd()  # show the working directory; the relative '..' added to sys.path resolves against it

In [None]:
from data_utils import *
from data.PILArNet import PILArNetDataModule
from torch.utils.data import Subset

# Glob of the PILArNet training shards (HDF5). Other locations used previously:
#   "../pilarnet/train/*.h5", "../../pilarnet/train/*.h5"
PILARNET_TRAIN_GLOB = "/sdf/data/neutrino/carsmith/foundation_models/pilarnet_model/pilarnet/train/*.h5"

dataset = PILArNetDataModule(
    data_path=PILARNET_TRAIN_GLOB,
    # NOTE(review): an earlier comment claimed "24 events per batch" for this value; confirm how the
    # data module maps batch_size to events.
    batch_size=48,
    num_workers=0,
    dataset_kwargs={
        "emin": 1.0e-2,  # min energy for log transform
        "emax": 20.0,  # max energy for log transform
        "energy_threshold": 0.13,  # remove points with energy < 0.13
        "remove_low_energy_scatters": True,  # remove low energy scatters (PID=4)
        "maxlen": 20000,  # max number of events to iterate over (an earlier config used -1 for this)
        "min_points": 1024,  # minimum number of points in an event
    },
)
dataset.setup()

# DataLoader over the training split
train_loader = dataset.train_dataloader()

In [5]:
# pick a data
# for batch in train_loader:
#     points = batch['points']
#     lengths = batch['lengths']
#     break
    
# # difference - for cifar, data loader does transforms
# # transformed_data = [transform(pc) for pc in raw_data]
# print(points[0, :, :].shape)
# data = points[0, :, :]
# transform = compute_train_transform(seed=45)
# x1 = transform(data)
# x2 = transform(data)

In [6]:
# fig = go.Figure(data=[
#     go.Scatter3d(
#         x=x1[:, 0], y=x1[:, 1], z=x1[:, 2],
#         mode='markers',
#         marker=dict(size=5, color='red'),
#         name='x1'
#     ),
#     go.Scatter3d(
#         x=x2[:, 0], y=x2[:, 1], z=x2[:, 2],
#         mode='markers',
#         marker=dict(size=5, color='blue'),
#         name='x2'
#     ),
#     go.Scatter3d(
#         x=data[:, 0], y=data[:, 1], z=data[:, 2],
#         mode='markers',
#         marker=dict(size=5, color='orange'),
#         name='original'
#     )  
# ])

# fig.show()

In [7]:
# converting 2 transformed views into sparse tensors
# for batch in train_loader:
#     for pc in batch['points']:
#         x1 = transform(pc)
#         x2 = transform(pc)
#         break
#     break
    
# device = 'cuda'
# x1 = torch.tensor(x1).to(device)
# x2 = torch.tensor(x2).to(device)

# coords = [x1[:, :3], x2[:, :3]]  # list of point clouds, each shape (Ni, 3)
# feats = [x1[:, 3:], x2[:, 3:]] # list of energies for each point cloud
# voxel_size = 0.05 # change to be real

# sparse_tensors = []

# for i, pc in enumerate(coords):
#     quantized_coords = torch.floor(pc / voxel_size).int()
    
#     # coordinates = ME.utils.batched_coordinates(quantized_coords)
#     batch_index = torch.full((quantized_coords.shape[0], 1), i, dtype=torch.int32, device=quantized_coords.device)
#     coords_with_batch = torch.cat([batch_index, quantized_coords], dim=1)  # shape (n, 4)
    
#     sparse_tensor = ME.SparseTensor(
#         features=feats[i].float(),           # shape (n, C)
#         coordinates=coords_with_batch      # shape (n, 1 + 3)
#     )
#     sparse_tensors.append(sparse_tensor)

# print(f'Input sizes for x1, x2: {sparse_tensors[0].shape, sparse_tensors[1].shape}')

In [5]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [6]:
# Preparing training data: pull every event out of the PILArNet loader once, then wrap the events in a
# SimCLR dataset that yields two augmented views per event (see simclr_collate below).
from data_utils import *

transform = compute_train_transform(seed=45)

raw_pointclouds = []
train_loader = dataset.train_dataloader()
for batch in tqdm(train_loader, desc="Extracting raw events"):
    points, lengths = batch["points"], batch["lengths"]
    # `points` is a batched tensor (batch, N, C); presumably only the first lengths[i] rows of event i are
    # real points and the rest is batch padding. Trim it so padding never reaches the augmentations /
    # voxelization. (If all events already have N points this slice is a no-op.)
    for pc, n_points in zip(points, lengths):
        raw_pointclouds.append(pc[: int(n_points)].cpu().numpy())

# Use the detected device instead of a hardcoded "cuda" so the cell also runs on CPU-only machines.
simclr_dataset = SimCLRPointCloudDataset(raw_pointclouds, transform, voxel_size=0.01, device=str(device))


def simclr_collate(batch):
    """Split a list of (view1, view2) pairs into two lists: all first views and all second views."""
    x1, x2 = zip(*batch)  # each is a list of SparseTensors
    return list(x1), list(x2)


simclr_loader = torch.utils.data.DataLoader(
    simclr_dataset,
    batch_size=256,
    shuffle=True,
    collate_fn=simclr_collate,
    num_workers=0,
    # NOTE(review): pin_memory only pins CPU tensors; if the dataset already builds samples on `device`,
    # this setting has no effect.
    pin_memory=True,
)

Extracting raw events: 100%|██████████| 208/208 [00:08<00:00, 24.13it/s]


In [7]:
# Preparing validation data for the linear probe: a 3D-extruded MNIST dataset split 80/20 into
# probe-train and probe-validation subsets.
from torch.utils.data import DataLoader, random_split

# NOTE(review): this draws both probe subsets from the MNIST *test* split (train=False).
mnist3d = MNIST3DExtrudedDataset(train=False, depth=3, voxel_size=1.0, device=str(device))

val_size = int(0.2 * len(mnist3d))         # 20% for validation
train_size = len(mnist3d) - val_size       # 80% for training

# A fixed-seed generator keeps the probe train/val split identical across kernel restarts, so probe
# accuracies from different runs are comparable.
split_generator = torch.Generator().manual_seed(45)
mnist_train_dataset, mnist_val_dataset = random_split(
    mnist3d, [train_size, val_size], generator=split_generator
)


def collate(batch):
    """Transpose a list of samples into a tuple of per-field tuples (presumably (inputs, labels))."""
    return tuple(zip(*batch))


mnist_train_loader = DataLoader(mnist_train_dataset, batch_size=16, shuffle=True, collate_fn=collate)
mnist_val_loader   = DataLoader(mnist_val_dataset,   batch_size=16, shuffle=False, collate_fn=collate)

In [None]:
import wandb

# The wandb run is started in the training cell below, using the hyperparameters the optimizer and
# DataLoader actually use. It is deliberately NOT initialised here, for two reasons:
# - The previous init in this cell referenced `epochs` before it is defined, which fails on a fresh kernel.
# - Its config (batch_size=16, lr=1e-3) disagreed with what is actually trained.

In [8]:
from model import *
from loss import *
import torch.optim as optim
import wandb

%autosave 120

epochs = 5
device = 'cuda'

# try tracking with wandb
wandb.init(
    project="simclr_encoder_pretraining",
    name="simclr-run-1",
    config={
        "epochs": epochs,
        "batch_size": 256,
        "lr": 1e-4,
        "temperature": 0.07,
    }
)

# instantiate model - currently, out_features in UNet_Encoder constructor is output of projection head
model = UNet_Encoder(in_channels=1)
wandb.watch(model, log="all", log_freq=100)
model.to(device)
model.train()

optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-6)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) # trying out learning rate decay
results = {}
results['eval_loss'] = []
results['eval_top1'] = []

for epoch in range(1, epochs + 1):
    print("test")
    train_loss = train_unet(model, simclr_loader, optimizer, epoch, epochs)
    scheduler.step()
    torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': train_loss,}, 'checkpoint.pth')
    wandb.log({
        "simclr_train_loss": train_loss,
        "learning_rate": scheduler.get_last_lr()[0],
        "epoch": epoch,
    })
    
    # validation with classification head
    classifier = LinearProbe(out_dim=512, num_classes=10) # encoder embeddings have dim=512
    criterion = torch.nn.CrossEntropyLoss()
    val_optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3, weight_decay=1e-4)
    eval_loss, eval_acc = mnist_validate(model, classifier, mnist_train_loader, mnist_val_loader, criterion, val_optimizer, epochs=1)
    wandb.log({
        "mnist_val_loss": eval_loss,
        "mnist_val_accuracy": eval_acc,
        "epoch": epoch,
    })
    results['eval_loss'].append(eval_loss)
    results['eval_top1'].append(eval_acc)

np.save('results.npy', results)
wandb.finish()

Autosaving every 120 seconds
test


                                                                         

[Eval] Epoch final | Loss: 2.1481 | Accuracy: 0.4645
test


                                                                         

[Eval] Epoch final | Loss: 2.1522 | Accuracy: 0.4465
test


                                                                       

KeyboardInterrupt: 

In [None]:
# Show the per-epoch linear-probe metrics collected by the training cell (loss first, then top-1 accuracy).
for metric_name in ("eval_loss", "eval_top1"):
    print(results[metric_name])