In [8]:
import torch
import transformers
from torch.utils.data import DataLoader

from model import PromptLRN, VCPromptLRN, PromptOptim
from dataset import Dataset
from config import cfg

In [3]:
def top_k_acc(pred, y, top_k=1):
    if top_k == 1:
        acc = (pred.reshape(-1,) == y).sum() / y.shape[0] * 100
        return acc
    else:
        corr = 0
        for p, t in zip(pred, y):
            if t in p:
                corr += 1
        acc = corr / y.shape[0] * 100
        return acc 

In [5]:
torch.manual_seed(2022)
device = torch.device('cpu')

## 1. EuroSAT dataset

### 1.1. Training & Evaluation of 16 shot Learning

#### 1.1.1. Text Prompt Learning

In [None]:
# training
proptim = PromptOptim(cfg=cfg, device=device, dataset='eurosat', kshot=16, type='text', start_epoch=0)
proptim.train()

In [10]:
# evaluation
############## config ##############
device = torch.device('cpu')
dataset = 'eurosat'
epoch = 200
type = 'text'
kshot = 16
topk = 1
####################################

# set model and evaluation dataloader 
testset = Dataset(dataset, kshot, train=False)
testloader = DataLoader(testset, batch_size=100)
if type == 'text':
    model = PromptLRN(testset.labels, cfg, device)
else:
    model = VCPromptLRN(testset.labels, cfg, device)
# load trained 
state_dict = torch.load('./ckpt/{}_promptlearn_{}/{}_shot/model_epoch{}.pt'.format(dataset, type, kshot, epoch))
model.load_state_dict(state_dict())
model.eval().to(device)
ys = torch.tensor(testset.df.labels.values)
preds = torch.tensor([])
# evaluation iteration
with torch.no_grad():
    for step, pixel in enumerate(testloader):
        logits = model(pixel.to(device))
        pred = torch.topk(logits, k=topk, dim=1).indices
        preds = torch.cat([preds, pred], dim=0)
        if (step+1) % 50:
            print('{} images evaluated'.format(step * len(testloader)))
    acc = top_k_acc(preds, ys, top_k = topk)

print('top {} Accuracy on {} dataset with {} shot setting : {}%'.format(topk, dataset, kshot, acc))

ftfy or spacy is not installed using BERT BasicTokenizer instead of ftfy.


0 images evaluated
54 images evaluated
108 images evaluated
162 images evaluated
216 images evaluated
270 images evaluated
324 images evaluated
378 images evaluated
432 images evaluated
486 images evaluated
540 images evaluated
594 images evaluated
648 images evaluated
702 images evaluated
756 images evaluated
810 images evaluated
864 images evaluated
918 images evaluated
972 images evaluated
1026 images evaluated
1080 images evaluated
1134 images evaluated
1188 images evaluated
1242 images evaluated
1296 images evaluated
1350 images evaluated
1404 images evaluated
1458 images evaluated
1512 images evaluated
1566 images evaluated
1620 images evaluated
1674 images evaluated
1728 images evaluated
1782 images evaluated
1836 images evaluated
1890 images evaluated
1944 images evaluated
1998 images evaluated
2052 images evaluated
2106 images evaluated
2160 images evaluated
2214 images evaluated
2268 images evaluated
2322 images evaluated
2376 images evaluated
2430 images evaluated
2484 image

#### 1.1.2. Text Prompt + Visual Prompt Learning

In [None]:
# training
proptim = PromptOptim(cfg=cfg, device=device, dataset='eurosat', kshot=16, type='text+vision', start_epoch=0)
proptim.train()

In [11]:
# evaluation
############## config ##############
device = torch.device('cpu')
dataset = 'eurosat'
epoch = 200
type = 'text+vision'
kshot = 16
topk = 1
####################################

# set model and evaluation dataloader 
testset = Dataset(dataset, kshot, train=False)
testloader = DataLoader(testset, batch_size=100)
if type == 'text':
    model = PromptLRN(testset.labels, cfg, device)
else:
    model = VCPromptLRN(testset.labels, cfg, device)
# load trained 
state_dict = torch.load('./ckpt/{}_promptlearn_{}/{}_shot/model_epoch{}.pt'.format(dataset, type, kshot, epoch))
model.load_state_dict(state_dict())
model.eval().to(device)
ys = torch.tensor(testset.df.labels.values)
preds = torch.tensor([])
# evaluation iteration
with torch.no_grad():
    for step, pixel in enumerate(testloader):
        logits = model(pixel.to(device))
        pred = torch.topk(logits, k=topk, dim=1).indices
        preds = torch.cat([preds, pred], dim=0)
        if (step+1) % 50:
            print('{} images evaluated'.format(step * len(testloader)))
    acc = top_k_acc(preds, ys, top_k = topk)

print('top {} Accuracy on {} dataset with {} shot setting : {}%'.format(topk, dataset, kshot, acc))

ftfy or spacy is not installed using BERT BasicTokenizer instead of ftfy.


0 images evaluated
54 images evaluated
108 images evaluated
162 images evaluated
216 images evaluated
270 images evaluated
324 images evaluated
378 images evaluated
432 images evaluated
486 images evaluated
540 images evaluated
594 images evaluated
648 images evaluated
702 images evaluated
756 images evaluated
810 images evaluated
864 images evaluated
918 images evaluated
972 images evaluated
1026 images evaluated
1080 images evaluated
1134 images evaluated
1188 images evaluated
1242 images evaluated
1296 images evaluated
1350 images evaluated
1404 images evaluated
1458 images evaluated
1512 images evaluated
1566 images evaluated
1620 images evaluated
1674 images evaluated
1728 images evaluated
1782 images evaluated
1836 images evaluated
1890 images evaluated
1944 images evaluated
1998 images evaluated
2052 images evaluated
2106 images evaluated
2160 images evaluated
2214 images evaluated
2268 images evaluated
2322 images evaluated
2376 images evaluated
2430 images evaluated
2484 image

### 1.2. Training & Evaluation of 8-shot learning

#### 1.2.1. Text Prompt Learning

In [None]:
# training
proptim = PromptOptim(cfg=cfg, device=device, dataset='eurosat', kshot=8, type='text', start_epoch=0)
proptim.train()

In [12]:
# evaluation
############## config ##############
device = torch.device('cpu')
dataset = 'eurosat'
epoch = 100
type = 'text'
kshot = 8
topk = 1
####################################

# set model and evaluation dataloader 
testset = Dataset(dataset, kshot, train=False)
testloader = DataLoader(testset, batch_size=100)
if type == 'text':
    model = PromptLRN(testset.labels, cfg, device)
else:
    model = VCPromptLRN(testset.labels, cfg, device)
# load trained 
state_dict = torch.load('./ckpt/{}_promptlearn_{}/{}_shot/model_epoch{}.pt'.format(dataset, type, kshot, epoch))
model.load_state_dict(state_dict())
model.eval().to(device)
ys = torch.tensor(testset.df.labels.values)
preds = torch.tensor([])
# evaluation iteration
with torch.no_grad():
    for step, pixel in enumerate(testloader):
        logits = model(pixel.to(device))
        pred = torch.topk(logits, k=topk, dim=1).indices
        preds = torch.cat([preds, pred], dim=0)
        if (step+1) % 50:
            print('{} images evaluated'.format(step * len(testloader)))
    acc = top_k_acc(preds, ys, top_k = topk)

print('top {} Accuracy on {} dataset with {} shot setting : {}%'.format(topk, dataset, kshot, acc))

ftfy or spacy is not installed using BERT BasicTokenizer instead of ftfy.


0 images evaluated
54 images evaluated
108 images evaluated
162 images evaluated
216 images evaluated
270 images evaluated
324 images evaluated
378 images evaluated
432 images evaluated
486 images evaluated
540 images evaluated
594 images evaluated
648 images evaluated
702 images evaluated
756 images evaluated
810 images evaluated
864 images evaluated
918 images evaluated
972 images evaluated
1026 images evaluated
1080 images evaluated
1134 images evaluated
1188 images evaluated
1242 images evaluated
1296 images evaluated
1350 images evaluated
1404 images evaluated
1458 images evaluated
1512 images evaluated
1566 images evaluated
1620 images evaluated
1674 images evaluated
1728 images evaluated
1782 images evaluated
1836 images evaluated
1890 images evaluated
1944 images evaluated
1998 images evaluated
2052 images evaluated
2106 images evaluated
2160 images evaluated
2214 images evaluated
2268 images evaluated
2322 images evaluated
2376 images evaluated
2430 images evaluated
2484 image

#### 1.2.2. Text Prompt + Visual Prompt Learning

In [None]:
proptim = PromptOptim(cfg=cfg, device=device, dataset='eurosat', kshot=8, type='text+vision', start_epoch=0)
proptim.train()

In [13]:
# evaluation
############## config ##############
device = torch.device('cpu')
dataset = 'eurosat'
epoch = 100
type = 'text+vision'
kshot = 8
topk = 1
####################################

# set model and evaluation dataloader 
testset = Dataset(dataset, kshot, train=False)
testloader = DataLoader(testset, batch_size=100)
if type == 'text':
    model = PromptLRN(testset.labels, cfg, device)
else:
    model = VCPromptLRN(testset.labels, cfg, device)
# load trained 
state_dict = torch.load('./ckpt/{}_promptlearn_{}/{}_shot/model_epoch{}.pt'.format(dataset, type, kshot, epoch))
model.load_state_dict(state_dict())
model.eval().to(device)
ys = torch.tensor(testset.df.labels.values)
preds = torch.tensor([])
# evaluation iteration
with torch.no_grad():
    for step, pixel in enumerate(testloader):
        logits = model(pixel.to(device))
        pred = torch.topk(logits, k=topk, dim=1).indices
        preds = torch.cat([preds, pred], dim=0)
        if (step+1) % 50:
            print('{} images evaluated'.format(step * len(testloader)))
    acc = top_k_acc(preds, ys, top_k = topk)

print('top {} Accuracy on {} dataset with {} shot setting : {}%'.format(topk, dataset, kshot, acc))

ftfy or spacy is not installed using BERT BasicTokenizer instead of ftfy.


0 images evaluated
54 images evaluated
108 images evaluated
162 images evaluated
216 images evaluated
270 images evaluated
324 images evaluated
378 images evaluated
432 images evaluated
486 images evaluated
540 images evaluated
594 images evaluated
648 images evaluated
702 images evaluated
756 images evaluated
810 images evaluated
864 images evaluated
918 images evaluated
972 images evaluated
1026 images evaluated
1080 images evaluated
1134 images evaluated
1188 images evaluated
1242 images evaluated
1296 images evaluated
1350 images evaluated
1404 images evaluated
1458 images evaluated
1512 images evaluated
1566 images evaluated
1620 images evaluated
1674 images evaluated
1728 images evaluated
1782 images evaluated
1836 images evaluated
1890 images evaluated
1944 images evaluated
1998 images evaluated
2052 images evaluated
2106 images evaluated
2160 images evaluated
2214 images evaluated
2268 images evaluated
2322 images evaluated
2376 images evaluated
2430 images evaluated
2484 image

In [15]:
import easydict
from easydict import EasyDict as edict

In [16]:
result = edict()

# dataset
result.eurosat = edict()
result.fgvcaircraft = edict()

# k_shot
result.eurosat.shot8 = edict()
result.eurosat.shot16 = edict()

# bakcbone
result.eurosat.shot8.clip_vit_b32 = edict()
result.eurosat.shot16.clip_vit_b32 = edict()

# approach
result.eurosat.shot8.clip_vit_b32.text_prompt = edict()
result.eurosat.shot16.clip_vit_b32.visual_text_prompt = edict()

# hyperparams & accuracy
## 16 shot
result.eurosat.shot16.clip_vit_b32.text_prompt.v1 = edict()
result.eurosat.shot16.clip_vit_b32.text_prompt.v1.hyperparams = {'epochs':200, 'ctx_len':16}
result.eurosat.shot16.clip_vit_b32.text_prompt.v1.acc = 62.59

result.eurosat.shot16.clip_vit_b32.visual_text_prompt.v1 = edict()
result.eurosat.shot16.clip_vit_b32.visual_text_prompt.v1.hyperparams = {'epochs':200, 'ctx_len':16, 'v_ctx_len':5}
result.eurosat.shot16.clip_vit_b32.visual_text_prompt.v1.acc = 51.2

## 8 shot
result.eurosat.shot16.clip_vit_b32.text_prompt.v1 = edict()
result.eurosat.shot16.clip_vit_b32.text_prompt.v1.hyperparams = {'epochs':200, 'ctx_len':16}
result.eurosat.shot16.clip_vit_b32.text_prompt.v1.acc = 62.59

result.eurosat.shot16.clip_vit_b32.visual_text_prompt.v1 = edict()
result.eurosat.shot16.clip_vit_b32.visual_text_prompt.v1.hyperparams = {'epochs':200, 'ctx_len':16, 'v_ctx_len':5}
result.eurosat.shot16.clip_vit_b32.visual_text_prompt.v1.acc = 51.2


In [17]:
result.eurosat.shot8.clip_vit_b32.text_prompt.hyperparams

{'model': {'ctx_len': 16, 'v_ctx_len': 5, 't_h_dim': 512, 'v_h_dim': 768},
 'train': {'n_epochs': 100,
  'batch_size': 32,
  'k_shot': 16,
  'base_lr': 1e-05,
  'max_lr': 0.002,
  'pct_start': 0.01},
 'eval': {}}