# Demo of Basic Finetuning Steps

## Import models, dataset, and helper functions

In [1]:
import clip
import matplotlib.pyplot as plt
import torch
from transformers import GPT2Tokenizer

from base_finetune import fine_tune
from clip_model import CaptionModel
from dataset import InstagramDataset

In [2]:
# Device string used later by `model.to(device)` and passed to
# `fine_tune(..., device=device)`.
device = "cpu"

## Initialize CLIP Model, Tokenizer, and Datasets

In [3]:
# Load the ViT-B/32 CLIP model together with its image preprocessing
# transform. The target device comes from the `device` variable defined in
# the configuration cell instead of a second hard-coded "cpu", so the
# notebook has a single place that controls device placement.
# jit=False requests the regular (non-TorchScript) model.
# NOTE(review): before setting `device` to anything other than "cpu",
# confirm that InstagramDataset moves its inputs to the CLIP model's device.
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

# Pretrained GPT-2 tokenizer; passed to InstagramDataset in the next cell.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

In [4]:
# Both datasets are built from the same CLIP model, preprocessing transform
# and tokenizer; the only difference is the `split` keyword.
dataset_args = (clip_model, preprocess, tokenizer)

# Training set: `split` is left at InstagramDataset's default.
train_data = InstagramDataset(*dataset_args)

# Held-out set, built from split="test".
# NOTE(review): `validation_data` is not referenced by any later cell shown
# in this notebook - confirm whether fine_tune should receive it.
validation_data = InstagramDataset(*dataset_args, split="test")

## Define Model and Load StateDict
This is where you would make adjustments to the model, such as its constructor arguments or the checkpoint it loads.

In [5]:
# Build the caption model.
# NOTE(review): the meaning of the positional argument 10 is defined by
# CaptionModel in clip_model.py - confirm there before changing it.
model = CaptionModel(10)

# Restore the weights stored in state_dicts/coco_weights.pt, mapping every
# tensor to the CPU while loading.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
model.load_state_dict(
    torch.load("state_dicts/coco_weights.pt", map_location="cpu")
)

# Switch to evaluation mode, then move the model to the configured device.
# NOTE(review): the model is left in eval mode right before fine-tuning;
# confirm fine_tune re-enables training mode if that is required.
model = model.eval().to(device)

## Start Finetuning

In [None]:
# Run one epoch of fine-tuning on the training set in batches of 32.
# fine_tune returns two values: the first is kept as `train_loss` (plotted
# in the final cell) and the second is discarded.
train_loss, _ = fine_tune(
    model,
    train_data,
    epochs=1,
    batch_size=32,
    device=device,
)

Training Epoch 1
>>>


  0%|          | 0/887 [00:00<?, ?batch/s]

Batch 1/887


  0%|          | 2/887 [03:31<26:55:50, 109.55s/batch]

Batch 2/887


  0%|          | 3/887 [06:40<35:43:30, 145.49s/batch]

Batch 3/887


  0%|          | 4/887 [08:31<32:23:41, 132.07s/batch]

Batch 4/887


  1%|          | 5/887 [10:38<31:56:12, 130.35s/batch]

Batch 5/887


  1%|          | 6/887 [26:46<101:31:22, 414.85s/batch]

Batch 6/887


  1%|          | 7/887 [29:48<82:51:43, 338.98s/batch] 

Batch 7/887


  1%|          | 8/887 [34:21<77:37:34, 317.92s/batch]

Batch 8/887


  1%|          | 9/887 [42:05<88:39:10, 363.50s/batch]

Batch 9/887


  1%|          | 10/887 [43:54<69:24:03, 284.88s/batch]

Batch 10/887


  1%|          | 11/887 [1:12:27<175:39:53, 721.91s/batch]

Batch 11/887


  1%|▏         | 12/887 [2:25:05<444:22:24, 1828.28s/batch]

Batch 12/887


  1%|▏         | 13/887 [2:27:01<317:53:41, 1309.41s/batch]

Batch 13/887


  2%|▏         | 14/887 [2:45:50<304:20:06, 1254.99s/batch]

Batch 14/887


  2%|▏         | 15/887 [2:47:16<218:36:00, 902.48s/batch] 

Batch 15/887


  2%|▏         | 16/887 [2:48:39<158:40:46, 655.85s/batch]

Batch 16/887


  2%|▏         | 17/887 [2:50:01<116:49:55, 483.44s/batch]

Batch 17/887


  2%|▏         | 18/887 [2:51:23<87:35:29, 362.86s/batch] 

Batch 18/887


  2%|▏         | 19/887 [2:52:46<67:12:05, 278.72s/batch]

Batch 19/887


  2%|▏         | 20/887 [2:54:09<52:55:52, 219.78s/batch]

Batch 20/887


  2%|▏         | 21/887 [2:55:31<42:58:38, 178.66s/batch]

Batch 21/887


  2%|▏         | 22/887 [2:56:54<36:00:41, 149.87s/batch]

Batch 22/887


  3%|▎         | 23/887 [2:58:16<31:05:57, 129.58s/batch]

Batch 23/887


  3%|▎         | 24/887 [2:59:39<27:41:42, 115.53s/batch]

Batch 24/887


  3%|▎         | 25/887 [3:01:01<25:15:59, 105.52s/batch]

Batch 25/887


  3%|▎         | 26/887 [3:02:23<23:33:50, 98.53s/batch] 

Batch 26/887


  3%|▎         | 27/887 [3:03:46<22:22:53, 93.69s/batch]

Batch 27/887


  3%|▎         | 28/887 [3:05:08<21:31:20, 90.20s/batch]

Batch 28/887


  3%|▎         | 29/887 [3:22:31<89:38:37, 376.13s/batch]

Batch 29/887


  3%|▎         | 30/887 [3:23:53<68:31:26, 287.85s/batch]

Batch 30/887


  3%|▎         | 31/887 [3:25:15<53:46:41, 226.17s/batch]

Batch 31/887


  4%|▎         | 32/887 [3:26:38<43:28:22, 183.04s/batch]

Batch 32/887


  4%|▎         | 33/887 [3:28:00<36:14:46, 152.79s/batch]

Batch 33/887


  4%|▍         | 34/887 [3:29:22<31:10:50, 131.59s/batch]

Batch 34/887


  4%|▍         | 35/887 [3:30:44<27:38:34, 116.80s/batch]

Batch 35/887


  4%|▍         | 36/887 [3:32:07<25:11:32, 106.57s/batch]

Batch 36/887


  4%|▍         | 37/887 [3:33:30<23:28:36, 99.43s/batch] 

Batch 37/887


  4%|▍         | 38/887 [3:35:03<23:01:28, 97.63s/batch]

Batch 38/887


  4%|▍         | 39/887 [3:36:43<23:08:33, 98.25s/batch]

Batch 39/887


  5%|▍         | 40/887 [3:38:13<22:32:54, 95.84s/batch]

Batch 40/887


  5%|▍         | 41/887 [3:39:36<21:36:18, 91.94s/batch]

Batch 41/887


  5%|▍         | 42/887 [3:40:58<20:53:44, 89.02s/batch]

Batch 42/887


  5%|▍         | 43/887 [3:42:21<20:24:45, 87.07s/batch]

Batch 43/887


  5%|▍         | 44/887 [3:43:43<20:03:19, 85.65s/batch]

Batch 44/887


  5%|▌         | 45/887 [3:45:06<19:50:09, 84.81s/batch]

Batch 45/887


  5%|▌         | 46/887 [3:46:29<19:41:41, 84.31s/batch]

Batch 46/887


In [None]:
# Plot the loss values returned by fine_tune against their position in the
# returned sequence.
# NOTE(review): the x-axis is labelled "Batch" on the assumption that
# fine_tune records one loss value per batch - confirm in base_finetune.py.
batches = range(len(train_loss))

# Explicit figure/axes instead of the implicit pyplot state, with labels so
# the figure can be read on its own.
fig, ax = plt.subplots()
ax.plot(batches, train_loss)
ax.set_xlabel("Batch")
ax.set_ylabel("Training loss")
ax.set_title("Fine-tuning loss")
plt.show()