# Minimal training notebook to test the training pipeline

In [1]:
# Enable IPython's autoreload extension; mode 2 reloads all imported modules before
# each cell executes, so edits to project code (e.g. under src/) take effect without
# restarting the kernel.
%load_ext autoreload

%autoreload 2

## Imports

In [2]:
import helpers.set_path # needs to be there to set the correct project path

import pandas as pd

from src.data.load_data import get_train_loader, get_val_loader, get_test_loader, classes
from src.data.format_submissions import format_submissions

from pathlib import Path
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

import wandb
import torch
import os

In [3]:
from src.models import ViTModel as Model

# Allow faster, slightly less precise float32 matrix multiplications (TF32 or
# bfloat16-based internal math) on hardware that supports them.
torch.set_float32_matmul_precision(precision='high')

In [4]:
# --- Training schedule ---
MAX_EPOCHS: int = 50  # upper bound; the EarlyStopping callback may end training sooner
DELETE_MODEL_CHECKPOINTS: bool = True  # NOTE(review): presumably controls checkpoint cleanup after the run — confirm the training cell reads it

# --- Data / model input settings (passed to the loaders and to the model) ---
BATCH_SIZE: int = 32
IMAGE_SIZE: int = 224  # should match the pretrained ViT's input size (the log shows vit-large-patch16-224-in21k)
CROP_THRESHOLD: float = 0.05  # semantics defined in src.data.load_data / src.models — not visible here

## Load data

In [5]:
# All three splits share the same batching / preprocessing settings.
# NOTE(review): seed_everything() is only called later, in the training cell; if the
# loader factories do anything random at construction time (e.g. splitting), it is unseeded.
loader_settings = (BATCH_SIZE, IMAGE_SIZE, CROP_THRESHOLD)

train_dataloader = get_train_loader(*loader_settings)
val_dataloader = get_val_loader(*loader_settings)
test_dataloader = get_test_loader(*loader_settings)

## Initialize and run training

In [12]:
# Train the model, reload the best checkpoint, predict on the test set, validate,
# log the formatted submission to Weights & Biases, and optionally clean up the checkpoint.

# Runs are named after the model class so they are easy to tell apart in the W&B project.
wandb_logger = WandbLogger(project="ccv1", entity="safari_squad", name=Model.__name__, offline=False)

seed_everything(42)

model = Model(
    batch_size=BATCH_SIZE,
    image_size=IMAGE_SIZE,
    crop_threshold=CROP_THRESHOLD,
    lr=1e-6,  # NOTE(review): hard-coded learning rate; consider moving it to the config cell
)

# Keep only the single best checkpoint by val_loss. The W&B run version in the filename
# presumably keeps checkpoints from different runs from colliding.
checkpoint_callback = ModelCheckpoint(
    dirpath=Path("../models/checkpoints"),
    filename=f"{Model.__name__}_{wandb_logger.version}",
    monitor="val_loss",
    mode="min",
    save_top_k=1,
)

trainer = Trainer(
    max_epochs=MAX_EPOCHS,
    accelerator="gpu",
    logger=wandb_logger,
    callbacks=[
        EarlyStopping(monitor="val_loss", mode="min", patience=3),
        checkpoint_callback,
    ],
)

try:
    trainer.fit(model,
                train_dataloaders=train_dataloader,
                val_dataloaders=val_dataloader)

    # load_from_checkpoint is a classmethod that returns a NEW model instance, so call it
    # on the class (Lightning 2.x raises TypeError when it is called on an instance).
    best_model_path = checkpoint_callback.best_model_path
    if best_model_path:
        best_model = Model.load_from_checkpoint(best_model_path)
    else:
        # No checkpoint was written (e.g. training was interrupted before the first
        # validation run); fall back to the in-memory weights instead of crashing on "".
        print("No checkpoint was saved; using the in-memory model weights instead.")
        best_model = model

    submissions = format_submissions(
        trainer.predict(best_model, dataloaders=test_dataloader),
        classes,
    )

    trainer.validate(best_model, dataloaders=val_dataloader)

    wandb_logger.log_text("submission", dataframe=submissions)

    # Honor the config flag: once predictions are logged, the checkpoint is no longer needed.
    if DELETE_MODEL_CHECKPOINTS and best_model_path:
        Path(best_model_path).unlink(missing_ok=True)
finally:
    # Close the W&B run even if training, prediction or logging raised.
    wandb.finish()

Global seed set to 42
Some weights of the model checkpoint at google/vit-large-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-large-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU avai

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
Some weights of the model checkpoint at google/vit-large-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-large-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to

Predicting: 0it [00:00, ?it/s]