In [9]:
import torch
import transformers
from torch.utils.data import DataLoader

from model import PromptLRN, PromptOptim
from dataset import Dataset
from config import cfg

In [12]:
def top_k_acc(pred, y, top_k=1):
    """Return the top-k accuracy, in percent, of `pred` against labels `y`.

    A sample counts as correct when its label appears among that sample's
    predicted class indices.

    Args:
        pred: Tensor of predicted class indices. With ``top_k == 1`` it is
            flattened to one prediction per sample (so ``(N,)`` and ``(N, 1)``
            both work); otherwise row ``i`` holds the candidates for sample
            ``i`` (e.g. the ``indices`` of ``torch.topk(..., dim=1)``).
        y: Tensor of ground-truth class indices, one per sample.
        top_k: Number of candidates per sample. Only selects how `pred` is
            interpreted; `pred` itself is not truncated.

    Returns:
        The accuracy as a percentage. As before, this is a 0-dim tensor when
        ``top_k == 1`` and a Python float otherwise.

    Raises:
        ValueError: If `pred` and `y` do not describe the same number of
            samples. (Previously the top-k path silently truncated to the
            shorter input while still dividing by ``len(y)``.)
    """
    labels = y.reshape(-1)
    n_samples = labels.shape[0]

    # Normalise predictions to a (n_samples, n_candidates) matrix.
    if top_k == 1:
        candidates = pred.reshape(-1, 1)
    elif pred.dim() == 1:
        candidates = pred.unsqueeze(1)
    else:
        candidates = pred.flatten(1)

    if candidates.shape[0] != n_samples:
        raise ValueError(
            f"pred describes {candidates.shape[0]} samples but y has {n_samples} labels"
        )

    # One broadcast comparison replaces the per-sample Python loop.
    hits = (candidates == labels.unsqueeze(1)).any(dim=1)

    if top_k == 1:
        # Tensor arithmetic keeps the original tensor return type.
        return hits.sum() / n_samples * 100
    return int(hits.sum()) / n_samples * 100

def evaluate_acc(dataset, optimmodule, topk):
    """Compute the top-k accuracy (percent) of a trained prompt learner.

    Builds the test split for `dataset`, runs ``optimmodule.model`` over it in
    batches of 100 without gradients, and scores the top-`topk` predicted
    class indices against ``testset.df.labels`` using `top_k_acc`.

    Args:
        dataset: Dataset name forwarded to ``Dataset`` (the notebook uses
            ``'sun397'`` and ``'eurosat'``).
        optimmodule: Object exposing the trained network as ``.model``.
        topk: Number of highest-scoring classes kept per sample.

    Returns:
        Whatever `top_k_acc` returns for the collected predictions.

    Side effects:
        Switches ``optimmodule.model`` to eval mode and leaves it there.
    """
    # The dataset name was previously hard-coded to 'sun397', so every call
    # was scored on SUN397 regardless of the argument.
    # NOTE(review): 16 matches the kshot used by the training cells; confirm
    # it is the right value for the test split.
    testset = Dataset(dataset, 16, train=False)
    testloader = DataLoader(testset, batch_size=100)
    model = optimmodule.model
    model.eval()
    ys = torch.tensor(testset.df.labels.values)

    # Collect per-batch index tensors and concatenate once: this keeps the
    # integer dtype (no float seed tensor) and avoids re-copying every step.
    batch_preds = []
    with torch.no_grad():
        # NOTE(review): assumes each test batch is the model input itself and
        # already lives on the model's device -- confirm against Dataset.
        for pixel in testloader:
            logits = model(pixel)
            # torch.topk's keyword is `k`; `topk=` raises TypeError.
            top_indices = torch.topk(logits, k=topk, dim=1).indices
            # Labels are on the CPU, so bring predictions there as well.
            batch_preds.append(top_indices.cpu())

    if batch_preds:
        preds = torch.cat(batch_preds, dim=0)
    else:
        # Empty loader: keep a well-formed (0, topk) index tensor.
        preds = torch.empty((0, topk), dtype=torch.long)

    return top_k_acc(preds, ys, top_k=topk)

### 1. Training Text Prompt Learner

#### 1.1. Training

In [None]:
# Build the text-only ('text') prompt optimiser for SUN397 with kshot=16 and
# run its training; later evaluation cells reuse `opt_sun397_text`.
opt_sun397_text = PromptOptim(cfg, 'sun397', kshot=16, type='text')
opt_sun397_text.train()

In [None]:
# Build the text-only ('text') prompt optimiser for EuroSAT with kshot=16 and
# run its training; later evaluation cells reuse `opt_eurosat_text`.
opt_eurosat_text = PromptOptim(cfg, 'eurosat', kshot=16, type='text')
opt_eurosat_text.train()

#### 1.2. Evaluation

In [13]:
# Request top-5 accuracy for the SUN397 text-prompt optimiser and print it.
# Requires the training cell that defines `opt_sun397_text` to have run first.
acc1 = evaluate_acc('sun397', opt_sun397_text, topk=5)
print(acc1)

RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 0

In [None]:
# Request top-3 accuracy for the EuroSAT text-prompt optimiser and print it.
# Requires the training cell that defines `opt_eurosat_text` to have run first.
acc2 = evaluate_acc('eurosat', opt_eurosat_text, topk=3)
print(acc2)

### 2. Training Visual+Text Prompt Learner

#### 2.1. Training

In [None]:
# Build the combined 'text+vision' prompt optimiser for SUN397 with kshot=16
# and run its training; later evaluation cells reuse `opt_sun397_vt`.
opt_sun397_vt = PromptOptim(cfg, 'sun397', kshot=16, type='text+vision')
opt_sun397_vt.train()

In [None]:
# Build the combined 'text+vision' prompt optimiser for EuroSAT with kshot=16
# and run its training; later evaluation cells reuse `opt_eurosat_vt`.
opt_eurosat_vt = PromptOptim(cfg, 'eurosat', kshot=16, type='text+vision')
opt_eurosat_vt.train()

#### 2.2. Evaluation

In [None]:
# Request top-5 accuracy for the SUN397 text+vision optimiser and print it.
# Requires the training cell that defines `opt_sun397_vt` to have run first.
acc3 = evaluate_acc('sun397', opt_sun397_vt, topk=5)
print(acc3)

In [None]:
# Request top-3 accuracy for the EuroSAT text+vision optimiser and print it.
# Requires the training cell that defines `opt_eurosat_vt` to have run first.
acc4 = evaluate_acc('eurosat', opt_eurosat_vt, topk=3)
print(acc4)