## Setup

In [None]:
!pip install -Uqq wandb

!pip install clip@git+https://github.com/openai/CLIP.git
!pip install -Uqq bellem@git+https://github.com/bdsaglam/bellem.git
!pip show bellem
!pip show fastmtl

In [None]:
import json
import os
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms

In [None]:
from getpass import getpass

# Kaggle CLI authentication via environment variables. The key is never written into
# the notebook: an existing KAGGLE_KEY (e.g. from a platform secret) is reused,
# otherwise the user is prompted for it without echo.
os.environ.setdefault('KAGGLE_USERNAME', 'bdsaglam')
if not os.environ.get('KAGGLE_KEY'):
    os.environ['KAGGLE_KEY'] = getpass('Kaggle API key: ')

In [None]:
import wandb
from getpass import getpass

# Not referenced by the visible cells; presumably read by the training script — TODO confirm.
wandb_username = "bdsaglam"
# Take the API key from the environment when available, otherwise prompt for it,
# so no credential is stored in the notebook source or its outputs.
wandb_token = os.environ.get("WANDB_API_KEY") or getpass("W&B API key: ")

# Log in through the Python API rather than `!wandb login $token`, which would place
# the key on a shell command line (visible in process listings and the cell source).
wandb.login(key=wandb_token, relogin=True)

## Data

In [None]:
from fastai.vision.all import URLs
from fastai.data.all import untar_data

# Fetch and extract the IMAGENETTE_320 dataset through fastai, then resolve the
# returned location to an absolute directory path for use in the config below.
imagenette_path = Path(untar_data(URLs.IMAGENETTE_320))
imagenette_path = imagenette_path.absolute()
print(imagenette_path)

In [None]:
imagenette_sketch_path =  Path('./imagenette-sketch').absolute()

!rm -rf $imagenette_sketch_path 
!mkdir $imagenette_sketch_path 
!(cd $imagenette_sketch_path && kaggle datasets download -d bdsaglam/imagenette-sketch-classification && unzip imagenette-sketch-classification.zip)
!(rm ./imagenette-sketch/imagenette-sketch-classification.zip)

print(imagenette_sketch_path)

## Train & Evaluate

In [1]:
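# PLACEHOLDER FOR TRAINING SCRIPT
# This cell must define `main(args)`; the last cell calls it with `args.cfg` and
# `args.sweep_cfg` set to the paths of the JSON config files it writes.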

In [None]:
# Base run configuration. It is serialized to ./config.json below and handed to main().
# NOTE(review): the "imagenet"/"imagenet_sketch" keys point at the Imagenette datasets
# downloaded above; key names are kept as-is because main() presumably reads them.
config = {
  "seed": 42,
  "data": {
    "imagenet": {
      "path": str(imagenette_path),
      "batch_size": 16
    },
    "imagenet_sketch": {
      "path": str(imagenette_sketch_path),
      "valid_pct": 0.3,
      "batch_size": 16
    }
  },
  "clip": {
    "model_name": "RN50",
    "prec": "fp32"
  },
  "coop": {
    "class_specific_contexts": True,
    "n_ctx": 16
  },
  "train": {
    "n_epoch": 50,
    "lr": 1e-4
  },
  "wandb": {
    "mode": "offline",
    "entity": "bdsaglam",
    "project": "coop"
  }
}

# W&B sweep definition, serialized to ./sweep-config.json below.
# Dotted parameter names presumably map onto the nested `config` keys above
# (e.g. "train.lr" -> config["train"]["lr"]) — TODO confirm in main().
sweep_config = {
    "metric": {"name": "accuracy", "goal": "maximize"},
    "method": "bayes",
    "parameters": {
        "clip.model_name": {"values": ["RN50"]},
        "coop.n_ctx": {"values": [1, 8, 16, 32]},
        # The learning-rate range spans two orders of magnitude; sample it
        # log-uniformly. With the default uniform distribution ~90% of draws
        # would land above 1e-3 and the low end would barely be explored.
        "train.lr": {"distribution": "log_uniform_values", "min": 1e-4, "max": 1e-2}
    },
    # NOTE(review): `count` is not a W&B sweep-config key; presumably main() reads it
    # (e.g. as the number of runs for wandb.agent) — TODO confirm.
    "count": 20
}


In [None]:
from types import SimpleNamespace

# main() takes file paths, not dicts, so persist both configurations to disk first.
config_file = './config.json'
sweep_config_file = './sweep-config.json'

for payload, destination in ((config, config_file), (sweep_config, sweep_config_file)):
    with open(destination, 'w') as fh:
        json.dump(payload, fh)

# Mimic the argparse-style namespace main() expects.
args = SimpleNamespace(cfg=config_file, sweep_cfg=sweep_config_file)
main(args)