# Install Dependencies

In [None]:
pip install torch

Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)
  Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)
Collecting nvidia-curand-cu12==10.3.2.106 (from torch)
  Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)
Collectin

In [None]:
pip install torchvision submitit pyyaml numpy opencv-python gdown

Collecting submitit
  Downloading submitit-1.5.1-py3-none-any.whl (74 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/74.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m74.7/74.7 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: submitit
Successfully installed submitit-1.5.1


# Definitions and imports

In [None]:
import json
import os

import gdown
import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import roc_auc_score, roc_curve
from torch.utils.data import SubsetRandomSampler

In [None]:
class GsocDataset3(torch.utils.data.Dataset):
    """Map-style dataset over an HDF5 file holding inputs in ``jet`` and labels in ``Y``.

    Items are served from an in-memory window of ``preload_size`` consecutive rows:
    when a requested index falls outside the current window, the whole window that
    contains it is read from the file. Access patterns that stay inside one window
    (e.g. a chunked sampler whose chunk size equals ``preload_size``) therefore hit
    the file once per window rather than once per item.

    Args:
        h5_path: path to the HDF5 file.
        transforms: optional callable applied to each input tensor.
        preload_size: number of consecutive rows read per window.
    """

    def __init__(self, h5_path, transforms=None, preload_size=3200):
        self.h5_path = h5_path
        self.transforms = transforms
        self.preload_size = preload_size
        self.h5_file = h5py.File(self.h5_path, 'r', libver='latest', swmr=True)
        self.data = self.h5_file['jet']
        self.labels = self.h5_file['Y']
        self.dataset_size = self.data.shape[0]
        self.chunk_size = self.data.chunks  # HDF5 chunk layout reported by h5py; not used below

        # Current in-memory window; preload_start == -1 means nothing is loaded yet.
        self.preloaded_data = None
        self.preloaded_labels = None
        self.preload_start = -1

    def __len__(self):
        return self.dataset_size

    def __getitem__(self, idx):
        """Return ``(input_tensor, label_tensor)`` for row ``idx``, loading its window if needed."""
        preload_start = (idx // self.preload_size) * self.preload_size
        if preload_start != self.preload_start:
            self.preload_start = preload_start
            preload_end = min(preload_start + self.preload_size, self.dataset_size)
            self.preloaded_data = self.data[preload_start:preload_end]
            self.preloaded_labels = self.labels[preload_start:preload_end]

        local_idx = idx - self.preload_start
        # np.asarray turns a numpy scalar (1-D source array) into a 0-d array, which
        # torch.from_numpy accepts; ndarrays pass through unchanged.
        data = torch.from_numpy(np.asarray(self.preloaded_data[local_idx]))
        labels = np.asarray(self.preloaded_labels[local_idx])

        if self.transforms:
            data = self.transforms(data)
        return data, torch.from_numpy(labels)

    def __del__(self):
        # h5_file is missing if h5py.File raised inside __init__; don't raise again here.
        h5_file = getattr(self, 'h5_file', None)
        if h5_file is not None:
            h5_file.close()

class ChunkedSampler(torch.utils.data.Sampler):
    """Yield indices in contiguous chunks, optionally shuffling the order of the chunks.

    Keeping each chunk's indices consecutive lets a dataset that preloads runs of
    consecutive rows serve a whole chunk from one read, while shuffling at chunk
    level still randomises the epoch order.

    Args:
        data_source: either a ``torch.utils.data.Dataset`` (positions ``0 .. len-1``
            are sampled) or a sequence of integer indices (those exact values are
            sampled, in the given order).
        chunk_size: number of consecutive indices per chunk; the last chunk may be
            shorter and is never dropped.
        shuffle: if True, the chunk order is re-drawn with ``np.random`` on every
            ``__iter__`` call; the order inside each chunk is preserved.

    Raises:
        ValueError: if ``chunk_size`` is smaller than 1.
    """

    def __init__(self, data_source, chunk_size=3200, shuffle=False):
        if chunk_size < 1:
            raise ValueError(f'chunk_size must be >= 1, got {chunk_size}')
        self.data_source = data_source
        self.chunk_size = chunk_size
        self.shuffle = shuffle
        if isinstance(data_source, torch.utils.data.Dataset):
            base_indices = list(range(len(data_source)))
        else:
            # An explicit index list: sample its values, not merely its positions.
            base_indices = [int(i) for i in data_source]
        # Unshuffled order; every shuffle starts from here so chunk boundaries never drift.
        self._base_indices = base_indices
        # Ceiling division: a trailing partial chunk counts as a chunk.
        self.num_chunks = (len(base_indices) + chunk_size - 1) // chunk_size
        self.indices = list(base_indices)

    def shuffle_indices(self):
        """Set ``self.indices`` to the base indices with their chunks in a random order."""
        base = self._base_indices
        chunks = [base[start:start + self.chunk_size] for start in range(0, len(base), self.chunk_size)]
        order = np.random.permutation(len(chunks))
        self.indices = [idx for chunk_id in order for idx in chunks[chunk_id]]

    def __iter__(self):
        if self.shuffle:
            self.shuffle_indices()
        return iter(self.indices)

    def __len__(self):
        return len(self._base_indices)

# Clone the I-JEPA GitHub repository

In [None]:
!git clone https://github.com/3podi/ijepa_gsoc.git

Cloning into 'ijepa_gsoc'...
remote: Enumerating objects: 174, done.[K
remote: Counting objects: 100% (131/131), done.[K
remote: Compressing objects: 100% (59/59), done.[K
remote: Total 174 (delta 81), reused 94 (delta 65), pack-reused 43[K
Receiving objects: 100% (174/174), 50.29 KiB | 2.29 MiB/s, done.
Resolving deltas: 100% (92/92), done.


# Run pre-training script

In [None]:
!ls

CODE_OF_CONDUCT.md  CONTRIBUTING.md  main_dist.py	  main.py    src
configs		    LICENSE	     main_distributed.py  README.md


In [None]:
%cd ijepa_gsoc
%mkdir /Data
%mkdir /Logging
%mkdir /Logging/vit_base
gdown.download('https://drive.google.com/uc?id=1PYFc_sGkFu91k-bhLG0UNmHXY0XXm_DI', '/Data/after_outlier_mean_std_record_dataset.json', quiet=False)
gdown.download('https://drive.google.com/uc?id=1fgVCK9xw_ydXOtM-6haY_mcb0hU7P1DD', '/Data/Dataset_normalized.h5', quiet=False)
gdown.download('https://drive.google.com/uc?id=1RsxytIC9-e_UFYZLv2fToASLPasILePh', '/Data/Dataset_Specific_labelled.h5', quiet=False)

/content/ijepa_gsoc


In [None]:
!python main.py \
   --fname /content/ijepa_gsoc/configs/vit_b_14.yaml \
   --devices cuda:0

INFO:root:loaded params...
{   'data': {   'batch_size': 256,
                'color_jitter_strength': 0.0,
                'crop_scale': [0.3, 1.0],
                'crop_size': 126,
                'image_folder': '/content/drive/MyDrive/Data/Dataset_normalized.h5',
                'num_workers': 2,
                'pin_mem': True,
                'root_path': '/content/drive/MyDrive/Data/Dataset_normalized.h5',
                'use_color_distortion': False,
                'use_gaussian_blur': False,
                'use_horizontal_flip': False},
    'logging': {   'folder': '/content/drive/MyDrive/Logging/vit_base_lr_fixed',
                   'write_tag': 'ijepa_gsoc'},
    'mask': {   'allow_overlap': False,
                'aspect_ratio': [0.75, 1.5],
                'enc_mask_scale': [0.85, 1.0],
                'min_keep': 10,
                'num_enc_masks': 1,
                'num_pred_masks': 4,
                'patch_size': 14,
                'pred_mask_scale': [0.15, 0.2

# Train the classification head on top of the frozen I-JEPA encoder

In [None]:
import src.models.vision_transformer as vit

# Rebuild the pre-training encoder architecture: vit_base with 126-pixel inputs and
# patch size 14 (matching crop_size and patch_size in the config printed above).
encoder_ijepa = vit.vit_base(img_size=[126], patch_size=14)

In [None]:
# NOTE(review): this path must point at wherever pre-training actually wrote its
# checkpoints (the printed config uses its own `logging.folder`) — confirm.
ijepa_path = '/Logging/vit_base/ijepa_gsoc-ep8.pth.tar'  # choose the epoch you want
# map_location='cpu' lets a checkpoint saved from GPU tensors load on a CPU-only machine
# (the device cell below falls back to CPU); the model is moved to `device` later anyway.
checkpoint = torch.load(ijepa_path, map_location='cpu')

encoder_ijepa.load_state_dict(checkpoint['target_encoder'])


In [None]:
class FromPretrained(nn.Module):
    """Linear classification head on top of a (by default frozen) pretrained encoder.

    The encoder output is summed over dim 1, batch-normalised with ``BatchNorm1d``
    and mapped to ``num_classes`` logits by a single ``nn.Linear`` layer.

    Args:
        pretrained_model: encoder module; its output is assumed to be
            (batch, tokens, hidden_dim) — TODO confirm for encoders other than the ViT here.
        hidden_dim: feature size of the encoder output (its last dimension).
        num_classes: number of output logits.
        req_grad: if False (default), the encoder's parameters are frozen and the encoder
            stays in eval mode even when this module is switched to train mode.
    """

    def __init__(self, pretrained_model, hidden_dim, num_classes, req_grad=False):
        super().__init__()
        self.pretrained_model = pretrained_model
        self.req_grad = req_grad

        for param in self.pretrained_model.parameters():
            param.requires_grad = req_grad
        self.linear = nn.Linear(hidden_dim, num_classes)
        self.batchnorm = nn.BatchNorm1d(hidden_dim)

    def train(self, mode=True):
        """Set train/eval mode; a frozen encoder is always kept in eval mode.

        A frozen encoder is a fixed feature extractor. Any train-mode-dependent layers
        it contains (e.g. dropout) would otherwise randomise the features the head is
        fitted on, so in that case only the head follows ``mode``.
        """
        super().train(mode)
        if not self.req_grad:
            self.pretrained_model.eval()
        return self

    def forward(self, x):
        features = self.pretrained_model(x)
        pooled = features.sum(dim=1)  # pool over dim 1 (tokens, under the assumption above)
        return self.linear(self.batchnorm(pooled))

In [None]:
# Wrap the encoder with a trainable head. hidden_dim must equal the encoder's output feature
# size; 768 is consistent with the 2305-parameter head reported by the next cell (768*3 + 1).
pretrained_model = FromPretrained(encoder_ijepa, hidden_dim=768, num_classes=1, req_grad=False)

In [None]:
# Verify parameter inclusion
original_model_params = sum(p.numel() for p in encoder_ijepa.parameters())
wrapped_model_params = sum(p.numel() for p in pretrained_model.parameters())
trainable_params = sum(p.numel() for p in pretrained_model.parameters() if p.requires_grad is True)
print('Original model n of parameters: ', original_model_params)
print('Number of parameters in the classifier: ', wrapped_model_params-original_model_params)
print('Trainable params: ', trainable_params)

Original model n of parameters:  86323200
Number of parameters in the classifier:  2305
Trainable params:  86325505


In [None]:
# DataLoader settings — the same values as the `data` section of the pre-training config.
batch_size, pin_mem, num_workers = 256, True, 2

In [None]:
labelled_path = '/Data/Dataset_Specific_labelled.h5'

# preload_size == batch_size so each preload window lines up with one sampler chunk.
dataset = GsocDataset3(labelled_path, preload_size=batch_size)
dataset_size = len(dataset)

# Split at chunk granularity: whole blocks of `batch_size` consecutive rows go to either
# train or validation. The two sets are disjoint, and every block still maps onto a single
# preload window. Seeded so the split is identical across runs.
split_rng = np.random.default_rng(0)
chunk_starts = np.arange(0, dataset_size, batch_size)
split_rng.shuffle(chunk_starts)
n_train_chunks = int(0.8 * len(chunk_starts))

# Sorting the block starts keeps the one possibly-short final block at the end of its split.
train_indices = [i for start in np.sort(chunk_starts[:n_train_chunks])
                 for i in range(start, min(start + batch_size, dataset_size))]
val_indices = [i for start in np.sort(chunk_starts[n_train_chunks:])
               for i in range(start, min(start + batch_size, dataset_size))]
train_size = len(train_indices)

assert set(train_indices).isdisjoint(val_indices), 'train/val overlap'
assert len(train_indices) + len(val_indices) == dataset_size

# ChunkedSampler over a Dataset yields positions 0..len-1; Subset maps those positions
# onto the selected rows, so each loader only ever sees its own split.
train_set = torch.utils.data.Subset(dataset, train_indices)
val_set = torch.utils.data.Subset(dataset, val_indices)

train_sampler = ChunkedSampler(train_set, chunk_size=batch_size, shuffle=True)
val_sampler = ChunkedSampler(val_set, chunk_size=batch_size, shuffle=False)

train_data_loader = torch.utils.data.DataLoader(train_set,
                                                batch_size=batch_size,
                                                sampler=train_sampler,
                                                pin_memory=pin_mem,
                                                num_workers=num_workers)

val_data_loader = torch.utils.data.DataLoader(val_set,
                                              batch_size=batch_size,
                                              sampler=val_sampler,
                                              pin_memory=pin_mem,
                                              num_workers=num_workers)

In [None]:
def lr_lambda(epoch):
    """Multiplicative LR factor for ``LambdaLR``: linear warmup, then exponential decay.

    For epochs 0-4 the factor rises linearly from 0.2 to 1.0; from epoch 5 onwards it
    is ``0.8 ** (epoch - 4)``.
    """
    warmup_epochs, decay_rate = 5, 0.8
    if epoch >= warmup_epochs:
        return decay_rate ** (epoch + 1 - warmup_epochs)
    return (epoch + 1) / float(warmup_epochs)

In [None]:
loss_function = nn.BCEWithLogitsLoss()
# Frozen encoder parameters never receive a .grad, so only the head is actually updated.
optimizer = torch.optim.AdamW(pretrained_model.parameters(), lr=5e-5)
# ExponentialLR multiplies the LR by gamma on every scheduler.step() call.
# (Warmup alternative: torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda).)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
pretrained_model = pretrained_model.to(device)

In [None]:
stats_path = '/Data/after_outlier_mean_std_record_dataset.json'

# Per-channel normalisation statistics written alongside the dataset.
with open(stats_path, 'r') as stats_file:
    stats = json.load(stats_file)

mean = np.array(stats['after_outlier_mean'])
std = np.array(stats['after_outlier_std'])

In [None]:
# exist_ok makes the cell safe to re-run; parent folders are created as needed.
os.makedirs('/Logging/vit_base/Classification/', exist_ok=True)

In [None]:
save_path = '/Logging/vit_base/Classification/'


def save_checkpoint():
    """Save ``pretrained_model.state_dict()`` under key 'model' to save_path/classification_model.pth.tar."""
    checkpoint_file = save_path + 'classification_model.pth.tar'
    torch.save({'model': pretrained_model.state_dict()}, checkpoint_file)

In [None]:
from src.utils.logging import CSVLogger

# One row per training iteration.
log_file1 = '/Logging/vit_base/Classification/classification_log.csv'
csv_logger1 = CSVLogger(
    log_file1,
    ('%d', 'epoch'),
    ('%d', 'itr'),
    ('%.5f', 'train_loss'),
)

# One row per epoch: train/validation loss and accuracy.
log_file2 = '/Logging/vit_base/Classification/epoch_classification_log.csv'
csv_logger2 = CSVLogger(
    log_file2,
    ('%d', 'epoch'),
    ('%.5f', 'train_loss'),
    ('%.5f', 'val_loss'),
    ('%.5f', 'train_acc'),
    ('%.5f', 'val_acc'),
)

In [None]:
epochs = 50

# Early stopping: stop once the validation loss has failed to improve by more than
# `delta` for `patience` consecutive epochs.
patience = 5
delta = 0.001
best_val_loss = float('inf')
epochs_no_improve = 0

In [None]:
# Per-channel normalisation statistics shaped (1, C, 1, 1) to broadcast over NCHW batches.
# They stay float64, matching the original arithmetic against the numpy arrays.
mean_t = torch.as_tensor(mean).reshape(1, -1, 1, 1)
std_t = torch.as_tensor(std).reshape(1, -1, 1, 1)


def preprocess_batch(batch_inputs):
    """Turn a channels-last loader batch into normalised NCHW float32 input on `device`.

    Permutes (B, H, W, C) -> (B, C, H, W), zero-pads one row at the bottom and one column
    on the right, then standardises each channel with `mean` / `std`.
    """
    x = batch_inputs.permute(0, 3, 1, 2)
    # NOTE(review): padding happens before normalisation, so the border becomes -mean/std
    # rather than 0 — presumably to match pre-training preprocessing; confirm in main.py.
    x = torch.nn.functional.pad(x, (0, 1, 0, 1), mode='constant', value=0)
    x = (x - mean_t) / std_t
    return x.float().to(device)


best_state = None  # CPU copy of the weights from the best validation epoch so far

for epoch in range(epochs):
    pretrained_model.train()
    total_loss = 0.0
    correct_predictions = 0
    total_predictions = 0

    for batch_idx, (inputs, labels) in enumerate(train_data_loader):
        inputs = preprocess_batch(inputs)
        labels = labels.to(device)

        # Forward pass and loss
        outputs = pretrained_model(inputs)
        loss = loss_function(outputs, labels)
        batch_loss = loss.item()
        total_loss += batch_loss

        # Accuracy: threshold the sigmoid probability at 0.5
        predicted_labels = (torch.sigmoid(outputs) > 0.5).float()
        correct_predictions += (predicted_labels == labels).sum().item()
        total_predictions += labels.size(0)

        # Backward pass and optimisation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log a plain float rather than the loss tensor.
        csv_logger1.log(epoch + 1, batch_idx, batch_loss)

    epoch_loss = total_loss / len(train_data_loader)
    epoch_accuracy = correct_predictions / total_predictions * 100
    print(f'Epoch {epoch+1}\nTrain loss: {epoch_loss:.4f}, Train accuracy: {epoch_accuracy:.2f}%')

    # Validation loop
    pretrained_model.eval()
    val_loss = 0.0
    val_correct_predictions = 0
    val_total_predictions = 0

    with torch.no_grad():
        for val_batch_idx, (val_inputs, val_labels) in enumerate(val_data_loader):
            val_inputs = preprocess_batch(val_inputs)
            val_labels = val_labels.to(device)

            val_outputs = pretrained_model(val_inputs)
            val_loss += loss_function(val_outputs, val_labels).item()

            val_predicted_labels = (torch.sigmoid(val_outputs) > 0.5).float()
            val_correct_predictions += (val_predicted_labels == val_labels).sum().item()
            val_total_predictions += val_labels.size(0)

    val_epoch_loss = val_loss / len(val_data_loader)
    val_epoch_accuracy = val_correct_predictions / val_total_predictions * 100
    print(f'Validation loss: {val_epoch_loss:.4f}, Validation accuracy: {val_epoch_accuracy:.2f}%')

    csv_logger2.log(epoch + 1, epoch_loss, val_epoch_loss, epoch_accuracy, val_epoch_accuracy)

    # ExponentialLR multiplies the LR by gamma on each step(), so step once per epoch.
    # Stepping per batch collapses the LR within a few dozen batches (0.8 ** 50 ~ 1.4e-5).
    scheduler.step()

    if best_val_loss - val_epoch_loss > delta:
        best_val_loss = val_epoch_loss
        epochs_no_improve = 0
        best_state = {k: v.detach().to('cpu', copy=True) for k, v in pretrained_model.state_dict().items()}
    else:
        epochs_no_improve += 1

    if epochs_no_improve >= patience:
        print(f'Early stopping triggered after {epoch+1} epochs')
        break

# Early stopping tracked the validation loss; restore the best epoch's weights so the model
# (and anything saved from it afterwards) is the best one rather than the last one.
if best_state is not None:
    pretrained_model.load_state_dict(best_state)



  return F.conv2d(input, weight, bias, self.stride,


Epoch 1
Train loss: 0.6226, Train accuracy: 64.48%
Validation loss: 1.3113, Validation accuracy: 44.30%
Epoch 2
Train loss: 0.5988, Train accuracy: 67.16%
Validation loss: 0.6283, Validation accuracy: 66.65%
Epoch 3
Train loss: 0.5988, Train accuracy: 67.16%
Validation loss: 0.6239, Validation accuracy: 66.95%
Epoch 4
Train loss: 0.5988, Train accuracy: 67.16%
Validation loss: 0.6233, Validation accuracy: 66.95%
Epoch 5
Train loss: 0.5988, Train accuracy: 67.16%
Validation loss: 0.6239, Validation accuracy: 66.95%
Epoch 6
Train loss: 0.5988, Train accuracy: 67.16%
Validation loss: 0.6239, Validation accuracy: 67.00%
Epoch 7
Train loss: 0.5988, Train accuracy: 67.16%
Validation loss: 0.6241, Validation accuracy: 66.95%
Epoch 8
Train loss: 0.5988, Train accuracy: 67.16%
Validation loss: 0.6239, Validation accuracy: 66.90%
Early stopping triggered after 8 epochs


In [None]:
# Write the classifier's current weights (pretrained_model.state_dict(), under key 'model')
# to /Logging/vit_base/Classification/classification_model.pth.tar.
save_checkpoint()