In [None]:
import torch
from rich.progress import track

from model.glim import GLIM
from data.datamodule import GLIMDataModule
from torchmetrics.functional.classification import multiclass_accuracy

# All inference in this notebook runs on one hard-coded GPU.
device = torch.device("cuda:0")

# Restore the trained GLIM model from its checkpoint.
# NOTE: `strict = False` lets loading succeed even if the checkpoint's keys do
# not exactly match the model's, so a partially loaded model will not raise here.
model = GLIM.load_from_checkpoint(
    "checkpoints/glim-zuco-epoch=199-step=49600.ckpt",
    map_location = device,
    strict = False,
    # Optional hyper-parameter overrides, currently disabled:
    # evaluate_prompt_embed = 'src',
    # prompt_dropout_probs = (0.0, 0.1, 0.1),
    )
model.setup(stage='test')
# A freshly loaded module is in training mode by default; switch to eval mode so
# that train-only behaviour (e.g. dropout) cannot perturb the evaluation below.
model.eval()

# Data module providing the test dataloader consumed by the evaluation loop.
dm = GLIMDataModule(data_path = './data/tmp/zuco_eeg_label_8variants.df',
                    eval_noise_input = False,
                    bsz_test = 24,   # presumably the test batch size -- confirm in GLIMDataModule
                    )
dm.setup(stage='test')

[Rank 0][GLIMDataModule] running `setup()`...
[Rank 0][GLIMDataModule] running `setup()`...Done! 😋😋😋


In [2]:
# Candidate topic descriptions the model scores every EEG sample against.
# Index 0 corresponds to 'task1' samples, index 1 to every other task key.
prefix = "The topic is about: "
template = ""
candidates = [
    prefix + template + topic
    for topic in ("movie, good or bad", "life experiences, relationship")
]

# One dict per test sample, collected across all batches.
results = []
with torch.no_grad():
    for batch in track(dm.test_dataloader()):
        eeg = batch['eeg'].to(device)
        eeg_mask = batch['mask'].to(device)
        # NOTE: [tuple('task'), tuple('dataset'), tuple('subject')] after collate
        prompts = batch['prompt']
        task_keys = batch['raw task key']               # list[str]
        relation_labels = batch['relation label']       # list[str]
        sentiment_labels = batch['sentiment label']     # list[str]

        # Binary ground truth: 0 for 'task1', 1 for anything else.
        batch_labels = torch.tensor(
            [0 if key == 'task1' else 1 for key in task_keys],
            device=device,
        )

        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            prob, gen_str = model.predict(eeg, eeg_mask, prompts, candidates, generate=True)

        results.extend(
            {
                'raw input text': batch['raw input text'][idx],
                'label': batch_labels[idx],
                'prob': prob[idx],
                'gen text': gen_str[idx],
                'relation label': relation_labels[idx],
                'sentiment label': sentiment_labels[idx],
            }
            for idx in range(len(eeg))
        )

# Aggregate the per-sample tensors and report micro-averaged top-1 accuracy.
probs = torch.stack([row['prob'] for row in results])
labels = torch.stack([row['label'] for row in results])
acc = multiclass_accuracy(probs, labels, num_classes=2, top_k=1, average='micro')
print('clip-like acc: ', acc.item())

Output()

clip-like acc:  0.9343296885490417


In [None]:
import pandas as pd
# Persist the per-sample results so the LLM evaluation cells can reload them
# without re-running EEG inference.
# NOTE: the 'label' entries are tensors created on `device`, so the pickle
# presumably needs the same device available when it is loaded -- confirm.
df_results = pd.DataFrame(results)
df_results.to_pickle('data/tmp/glim_gen_results.pkl')

In [9]:
from torch.utils.data import DataLoader

# Text-only baseline: score the ground-truth sentences and the EEG-generated
# sentences against the same `candidates` via the model's text-embedding path.
# Relies on `results`, `labels`, `candidates` and `model` from the cells above.
input_template = "To English: <MASK>"   # loop-invariant, so defined once
loader = DataLoader(results, batch_size=64, shuffle=False, drop_last=False)

probs1, probs2 = [], []
# Pure inference: disable autograd so the accumulated probabilities do not keep
# computation graphs alive (consistent with the EEG evaluation loop).
with torch.no_grad():
    for batch in track(loader):
        input_texts = batch['raw input text']
        gen_texts = batch['gen text']
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            probs1.append(model.predict_text_embedding(input_texts, input_template, candidates))
            probs2.append(model.predict_text_embedding(gen_texts, input_template, candidates))

# shuffle=False keeps rows aligned with `labels`, which was stacked from `results`.
probs1 = torch.cat(probs1, dim=0)
probs2 = torch.cat(probs2, dim=0)
acc1 = multiclass_accuracy(probs1, labels, num_classes=2, top_k=1, average='micro')
print('clip-like acc [raw input text]: ',acc1.item())
acc2 = multiclass_accuracy(probs2, labels, num_classes=2, top_k=1, average='micro')
print('clip-like acc [gen text]:       ',acc2.item())

Output()

clip-like acc [raw input text]:  0.91847825050354
clip-like acc [gen text]:        0.873641312122345


In [5]:
import torch
import transformers
# Instruction-tuned LLM used as an external judge for topic classification.
eval_llm_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Load the model in half precision onto the first GPU.
pipe = transformers.pipeline(
    model=eval_llm_id,
    device_map=torch.device("cuda:0"),
    model_kwargs=dict(torch_dtype=torch.float16),
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [None]:
import pandas as pd
# Reload the per-sample results written by the earlier `pd.to_pickle` cell, so
# the LLM evaluation can run without repeating EEG inference.
# NOTE: unpickling can execute arbitrary code -- only load files you produced yourself.
df_results = pd.read_pickle('data/tmp/glim_gen_results.pkl')

In [9]:
# Release cached-but-unused GPU memory held by PyTorch's allocator before running
# the LLM. Tensors that are still referenced (e.g. by `model`) are not freed.
torch.cuda.empty_cache()

In [10]:
# Sentences to classify with the LLM. Swap the two lines below to evaluate the
# EEG-generated text instead of the ground-truth input text.
# input_sentences = df_results['gen text'].tolist()
input_sentences = df_results['raw input text'].tolist()

# System prompt asking for a bare integer label (0 or 1) per sentence.
# Alternative wordings that were tried are kept commented out around it.
# NOTE(review): the prompt text contains typos ("You task", and an unbalanced
# quote after 'personal biography). It is left unchanged here because editing
# the prompt changes the measured accuracy -- fix deliberately and re-run.
# instructions = {"role": "system", 
#           "content": 
#             ("You task is to classify the most likely corpus source of the following sentence."
#              " Label '0' for 'movie review', '1' for 'personal biography."
#              " Please just output the integer label."
#             )}

instructions = {"role": "system", 
          "content": 
            ("You task is to classify the most likely topic of the following sentence."
             " Label '0' for 'movie review', '1' for 'personal biography."
             " Please just output the integer label."
            )}

# instructions = {"role": "system", 
#           "content": 
#             ("You task is to classify the most likely corpus source of the following sentence."
#              " Label '0' for 'Stanford Sentiment Treebank', '1' for 'Wikipedia relation extraction corpus."
#              " Please just output the integer label."
#             )}

# One chat (system + user turn) per sentence, rendered to prompt strings.
messages = [[instructions, {"role": "user", "content": sen}] for sen in input_sentences]
inputs = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# Stop on either the regular EOS token or the chat end-of-turn token.
terminators = [pipe.tokenizer.eos_token_id, 
               pipe.tokenizer.convert_tokens_to_ids("<|eot_id|>")]

# Batched generation needs a pad token; reuse EOS and pad on the left so every
# prompt ends directly before its generated continuation.
pipe.tokenizer.pad_token_id = pipe.tokenizer.eos_token_id
pipe.tokenizer.padding_side = 'left'

with torch.no_grad():
    outputs = pipe(inputs,
                   batch_size = 16,
                   max_new_tokens = 4,
                   eos_token_id = terminators,
                   # Deterministic beam search: sampling made the reported
                   # accuracy vary from run to run (no seed was set).
                   do_sample = False,
                   num_beams = 2,
                   pad_token_id = pipe.tokenizer.eos_token_id,
                   )

In [11]:
# Score the LLM's predictions against the ground-truth labels in `df_results`.
total = df_results.shape[0]
if len(outputs) != total:
    # Guards against stale `outputs` produced from a different set of sentences.
    raise ValueError(
        f"got {len(outputs)} LLM outputs for {total} result rows; "
        "re-run the generation cell"
    )

label_column = df_results['label']
n_correct = 0
n_unparsed = 0
for i in range(total):
    # `generated_text` starts with the prompt; keep only the completion.
    completion = outputs[i][0]['generated_text'][len(inputs[i]):]
    try:
        pred = int(completion)  # int, label id
    except ValueError:
        # The LLM did not answer with a bare integer: count it as a miss
        # instead of aborting the whole evaluation.
        pred = None
        n_unparsed += 1
    label = label_column.iloc[i].item()
    if pred == label:
        n_correct += 1

# An empty results frame yields NaN rather than a ZeroDivisionError.
llm_acc = n_correct / total if total else float('nan')
if n_unparsed:
    print(f'warning: {n_unparsed} of {total} LLM outputs were not an integer label')
print('llm-pred acc: ', llm_acc)

llm-pred acc:  0.8614130434782609
