# Demo of Basic Finetuning Steps

## Import models, dataset, and helper functions

In [1]:
import clip
import matplotlib.pyplot as plt
import torch
from transformers import GPT2Tokenizer

from base_finetune import fine_tune
from clip_model import CaptionModel
from dataset import InstagramDataset

In [2]:
# Device the caption model is moved to and that is passed on to fine_tune.
# Keep any other device arguments in this notebook in sync with this value.
device = "cpu"

## Initialize CLIP Model, Tokenizer, and Datasets

In [3]:
# Load the CLIP ViT-B/32 model and its image preprocessing transform on the
# notebook-wide `device` (previously hard-coded to "cpu", which would drift
# out of sync if `device` were changed). jit=False requests the non-JIT model.
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

# Pretrained GPT-2 tokenizer; handed to InstagramDataset alongside the CLIP model.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

In [4]:
# Both datasets share the same CLIP model, preprocessing transform and tokenizer.
dataset_args = (clip_model, preprocess, tokenizer)

# Training data uses InstagramDataset's default split.
train_data = InstagramDataset(*dataset_args)
# Held-out data is built from the "test" split.
validation_data = InstagramDataset(*dataset_args, split="test")

## Define Model and Load StateDict
This is the cell to edit if you want to make adjustments to the model before fine-tuning.

In [5]:
# Build the caption model and initialize it from the saved weights.
# NOTE(review): the meaning of the positional argument 10 is defined by
# CaptionModel in clip_model.py — confirm before changing it.
model = CaptionModel(10)

# Map the checkpoint tensors straight onto the notebook-wide `device`
# (previously hard-coded to "cpu").
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
state_dict = torch.load("state_dicts/coco_weights.pt", map_location=device)
model.load_state_dict(state_dict)

# NOTE(review): the model is switched to eval mode before fine-tuning;
# presumably fine_tune sets train mode itself — confirm.
model = model.eval()
model = model.to(device)

## Start Finetuning

In [None]:
# Run one fine-tuning epoch over the training set in batches of 32.
# Only the first returned value (the training loss) is kept; the second is discarded.
train_loss, _ = fine_tune(
    model,
    train_data,
    epochs=1,
    batch_size=32,
    device=device,
)

Training Epoch 1
>>>


  0%|          | 0/887 [00:00<?, ?batch/s]

In [None]:
# Plot the loss values returned by fine_tune against their position in the list.
batches = range(len(train_loss))

# Explicit figure/axes so the plot carries its own title and axis labels.
fig, ax = plt.subplots()
ax.plot(batches, train_loss)
ax.set(title="Fine-tuning loss", xlabel="Batch", ylabel="Training loss")
plt.show()