In [1]:
! pip install mlflow transformers datasets "pydantic[dotenv]" av

Collecting mlflow
  Downloading mlflow-2.0.1-py3-none-any.whl (16.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.5/16.5 MB[0m [31m75.0 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting av
  Downloading av-10.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (31.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.2/31.2 MB[0m [31m50.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting Flask<3
  Downloading Flask-2.2.2-py3-none-any.whl (101 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m101.5/101.5 kB[0m [31m22.8 MB/s[0m eta [36m0:00:00[0m
Collecting sqlparse<1,>=0.4.0
  Downloading sqlparse-0.4.3-py3-none-any.whl (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.8/42.8 kB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
Collecting alembic<2
  Downloading alembic-1.8.1-py3-none-any.whl (209 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m209.8/209.8

In [2]:
# Switch the working directory to the parent folder; the relative paths used later
# (sys.path entry ".", "imagenet_classes.txt") are resolved from there.
%cd ..

/notebooks


In [None]:
# List the contents of the current working directory.
!ls

In [4]:
import sys

# Put the current working directory on the import path (presumably so the project's
# `overfit` package can be imported by the following cell — confirm).
# Guarded so that re-running this cell does not keep appending duplicate entries.
if "." not in sys.path:
    sys.path.append(".")

In [1]:
import argparse
import logging
import os

import mlflow
from overfit.models.vit import ViT
from overfit.trainers.overfit import OverfitTrainer
from overfit.utils.misc import parse_video_path_params
from torchvision.io import read_video
from torchvision import transforms
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torchvision.models import ResNet50_Weights, resnet50
from torchvision.models import ResNet34_Weights, resnet34
from torchvision.models import ResNet18_Weights, resnet18
from pathlib import Path
import torch
from tqdm.notebook import tqdm
from mlflow.client import MlflowClient
from overfit.env_settings import settings

In [7]:
# Empty placeholder; the MLflow cells further down rebind this name per dataset.
MLFLOW_EXPERIMENT_ID = ""

# Hyperparameters forwarded unchanged to the trainer for every video.
CONFIDENCE = 0.1
WEIGHT_DECAY = 0.2
MAX_LR = 0.4
MOMENTUM = 0.1

# Source architectures this notebook can load, and the one selected for this run.
MODELS = ["resnet18", "resnet34", "resnet50", "vit"]
MODEL = MODELS[3]  # index 3 -> "vit"

# Dataset tags; each one corresponds to a folder /datasets/imagenet1k-<tag>.
DATASETS = ["4-50", "5-50", "6-75"]

# Every experiment name is "D<dataset>" followed by the same model/hyperparameter suffix.
_RUN_SUFFIX = f"M{MODEL}C{CONFIDENCE}WD{WEIGHT_DECAY}LR{MAX_LR}M{MOMENTUM}"
EXPERIMENT_NAMES = {tag: f"D{tag}{_RUN_SUFFIX}" for tag in DATASETS}
EXPERIMENT_NAMES

{'4-50': 'D4-50MvitC0.1WD0.2LR0.4M0.1',
 '5-50': 'D5-50MvitC0.1WD0.2LR0.4M0.1',
 '6-75': 'D6-75MvitC0.1WD0.2LR0.4M0.1'}

In [8]:
# Point both the low-level client and the fluent `mlflow` API at the same tracking server.
client = MlflowClient(tracking_uri=settings.MLFLOW_TRACKING_URI)
mlflow.set_tracking_uri(settings.MLFLOW_TRACKING_URI)

# Resolve one experiment id per dataset: reuse the experiment when its name is already
# registered, create it otherwise. Looking the name up first (instead of creating inside
# a catch-all try/except) lets genuine failures, e.g. an unreachable tracking server,
# surface directly rather than being masked by an AttributeError on a None lookup result.
MLFLOW_EXPERIMENT_IDS = {}
for DATASET, EXPERIMENT_NAME in EXPERIMENT_NAMES.items():
    existing_experiment = client.get_experiment_by_name(EXPERIMENT_NAME)
    if existing_experiment is None:
        MLFLOW_EXPERIMENT_ID = client.create_experiment(EXPERIMENT_NAME)
    else:
        MLFLOW_EXPERIMENT_ID = existing_experiment.experiment_id
    MLFLOW_EXPERIMENT_IDS[DATASET] = MLFLOW_EXPERIMENT_ID
MLFLOW_EXPERIMENT_IDS

{'4-50': '32', '5-50': '33', '6-75': '34'}

In [10]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [11]:
class ToFloat:
    """Callable transform that divides a tensor by 255.0 and casts it to float32."""

    def __call__(self, tensor):
        """Return ``tensor / 255.0`` converted to ``torch.float32``."""
        scaled = tensor / 255.0
        return scaled.to(torch.float32)

# Per-channel mean / std handed to Normalize (these match the commonly published
# ImageNet RGB statistics).
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Preprocessing applied in this order: resize to 224x224, rescale to float32 with
# ToFloat, then normalize each channel.
TRANSFORM_IMG = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        ToFloat(),
        transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
    ]
)

In [12]:
# List the dataset folders available under /datasets.
!ls /datasets

imagenet1k-4-50  imagenet1k-5-50  imagenet1k-6-75


In [13]:
# Read the class names, one per line, stripping the trailing newline from each entry.
with open("imagenet_classes.txt", "r") as class_file:
    categories = [line.rstrip("\n") for line in class_file]

In [15]:
# Map every supported MODEL name to its torchvision constructor and IMAGENET1K_V1 weights.
# Only references are stored here; no network is built until the lookup below.
_SOURCE_MODELS = {
    "vit": (vit_b_16, ViT_B_16_Weights.IMAGENET1K_V1),
    "resnet34": (resnet34, ResNet34_Weights.IMAGENET1K_V1),
    "resnet50": (resnet50, ResNet50_Weights.IMAGENET1K_V1),
    "resnet18": (resnet18, ResNet18_Weights.IMAGENET1K_V1),
}

if MODEL not in _SOURCE_MODELS:
    raise Exception("Unknown Source model")

# Build the selected pretrained network, switch it to eval mode and move it to `device`.
_build_model, _pretrained_weights = _SOURCE_MODELS[MODEL]
srcnet = _build_model(weights=_pretrained_weights).eval().to(device)

In [None]:
# For every dataset, run OverfitTrainer.test on each of its videos and record every
# video as a separate MLflow run inside that dataset's experiment.
for DATASET, MLFLOW_EXPERIMENT_ID in tqdm(MLFLOW_EXPERIMENT_IDS.items()):
    dataset_dir = Path(f"/datasets/imagenet1k-{DATASET}")
    # Path.glob yields entries in a filesystem-dependent order; sort so that runs are
    # created in the same order on every execution of the notebook.
    videos = sorted(dataset_dir.glob("*.mp4"))
    if not videos:
        # A missing or empty dataset folder would otherwise be skipped without a trace.
        logging.warning("No .mp4 videos found in %s", dataset_dir)
    logging.info("Creating trainer")
    for video_file in tqdm(videos):
        video_path = str(video_file)
        # Keep only the first item returned by read_video; output_format asks for TCHW layout.
        vid = read_video(video_path, output_format="TCHW")[0]
        vid = TRANSFORM_IMG(vid).to(device)
        # The helper only receives the path string; y_ix is used as the per-frame target below.
        y_ix, _, crop_fraction, n_frames = parse_video_path_params(video_path)
        logging.info(crop_fraction)
        logging.info(n_frames)

        # A new trainer is configured for each video.
        # NOTE(review): the same `srcnet` instance is handed to every trainer — confirm that
        # OverfitTrainer.set copies it rather than updating it in place across videos.
        tgtnet_trainer = OverfitTrainer(categories=categories)
        tgtnet_trainer.set(
            pretrained_classifier=srcnet,
            num_classes=1000,
            confidence=CONFIDENCE,
            weight_decay=WEIGHT_DECAY,
            max_lr=MAX_LR,
            momentum=MOMENTUM,
        )
        tgtnet_trainer.model = tgtnet_trainer.model.to(device)

        logging.info("Starting experiment")
        with mlflow.start_run(experiment_id=MLFLOW_EXPERIMENT_ID) as run:
            mlflow.log_param("Crop fraction", crop_fraction)
            mlflow.log_param("Frames", n_frames)
            mlflow.log_param("Filename", video_path)
            mlflow.log_param("Source Model", MODEL)
            # One target per frame: the label index repeated n_frames times.
            # NOTE(review): n_frames comes from the path, not from vid.shape[0] — confirm they agree.
            tgtnet_trainer.test(vid, [y_ix] * n_frames, active_run=run, hf_format=False)