### Zero-shot relation classification with `eeg_emb`

In [None]:
import torch
from rich.progress import track

from model.glim import GLIM
from data.datamodule import GLIMDataModule
from torchmetrics.functional.classification import multiclass_accuracy

# Restore the trained GLIM LightningModule onto the first GPU.
# `strict=False` lets loading proceed when checkpoint and model keys differ.
device = torch.device("cuda:0")
model = GLIM.load_from_checkpoint(
    "checkpoints/glim-zuco-epoch=199-step=49600.ckpt",
    map_location = device,
    strict = False,
    # Optional hyper-parameter overrides, currently disabled:
    # evaluate_prompt_embed = 'src',
    # prompt_dropout_probs = (0.0, 0.1, 0.1),
    )
model.setup(stage='test')
# No Trainer is used in the visible cells, so nothing else switches the module
# out of training mode. Call eval() explicitly so dropout / norm layers behave
# deterministically during the evaluation below (harmless if already in eval mode).
model.eval()

# Test split of the ZuCo EEG data, batches of 24, without the noise-input ablation.
dm = GLIMDataModule(data_path = './data/tmp/zuco_eeg_label_8variants.df',
                    eval_noise_input = False,
                    bsz_test = 24,
                    )
dm.setup(stage='test')

[Rank 0][GLIMDataModule] running `setup()`...
[Rank 0][GLIMDataModule] running `setup()`...Done! 😋😋😋


In [3]:
# One candidate sentence per relation class; `model.predict` scores each EEG
# sample against every candidate.
prefix = "Relation classification: "
template = "It is about <MASK>."
all_relations = ['awarding', 'education', 'employment',
                 'foundation', 'job title', 'nationality',
                 'political affiliation', 'visit', 'marriage']
candidates = [prefix + template.replace("<MASK>", label) for label in all_relations]
# relation name -> class index (position in `all_relations`)
relation_to_idx = {relation: idx for idx, relation in enumerate(all_relations)}

results = []
with torch.no_grad():
    for batch in track(dm.test_dataloader()):
        eeg = batch['eeg'].to(device)
        eeg_mask = batch['mask'].to(device)
        prompts = batch['prompt'] # NOTE: [tuple('task'), tuple('dataset'), tuple('subject')] after collate
        # Samples whose relation is not one of the 9 classes get -1, which the
        # metric below skips via `ignore_index=-1`.
        labels = torch.tensor(
            [relation_to_idx.get(relation, -1) for relation in batch['relation label']],
            device=device,
        )

        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            prob, gen_str = model.predict(eeg, eeg_mask, prompts, candidates, generate=False)

        # Keep one record per sample so predictions stay paired with their text.
        for i in range(len(eeg)):
            results.append({'input text':   batch['raw input text'][i],
                            'label':        labels[i],
                            'prob':         prob[i],
                            'gen_str':      gen_str[i]
                            })

probs = torch.stack([row['prob'] for row in results])
labels = torch.stack([row['label'] for row in results])
# Derive the class count from the label list instead of hard-coding 9.
n_relations = len(all_relations)
acc1 = multiclass_accuracy(probs, labels, num_classes=n_relations, top_k=1, ignore_index=-1, average='micro')
print('clip-like acc1: ',acc1.item())

acc3 = multiclass_accuracy(probs, labels, num_classes=n_relations, top_k=3, ignore_index=-1, average='micro')
print('clip-like acc3: ',acc3.item())


Output()

clip-like acc1:  0.3244898021221161
clip-like acc3:  0.5714285969734192


## Load generated texts, labels, and raw texts

In [None]:
import pandas as pd
# Load previously saved generation results. The cells below read the columns
# 'raw input text', 'gen text', 'sentiment label' and 'relation label'.
# NOTE(review): `read_pickle` unpickles the file, which can execute arbitrary
# code -- only load pickles produced by a trusted run of this project.
df_results = pd.read_pickle('data/tmp/glim_gen_results.pkl')

### Zero-shot relation classification on raw vs. generated text

In [44]:
from torch.utils.data import DataLoader, Dataset

class BatchedDF(Dataset):
    """Map-style dataset over four columns of a results DataFrame.

    The columns are copied into plain Python lists at construction time.
    Both label columns are converted to ``str`` element by element.
    """

    def __init__(self, df: pd.DataFrame) -> None:
        # Text columns are taken as-is.
        self.raw_input_text = df['raw input text'].tolist()
        self.gen_text = df['gen text'].tolist()
        # Label columns are stringified so every item has a uniform type.
        self.senti_label = df['sentiment label'].map(str).tolist()
        self.rela_label = df['relation label'].map(str).tolist()

    def __len__(self):
        """Number of rows taken from the DataFrame."""
        return len(self.raw_input_text)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict keyed by the original column names."""
        sample = {}
        sample['raw input text'] = self.raw_input_text[idx]
        sample['gen text'] = self.gen_text[idx]
        sample['sentiment label'] = self.senti_label[idx]
        sample['relation label'] = self.rela_label[idx]
        return sample
    

# Candidate sentences, one per relation class (same construction as the EEG cell).
prefix = "Relation classification: "
template = "It is about <MASK>."
all_relations = ['awarding', 'education', 'employment',
                 'foundation', 'job title', 'nationality',
                 'political affiliation', 'visit', 'marriage']
candidates = [prefix + template.replace("<MASK>", label) for label in all_relations]
# relation name -> class index (position in `all_relations`)
relation_to_idx = {relation: idx for idx, relation in enumerate(all_relations)}
# Template passed to `predict_text_embedding` for the input sentences; it does
# not change between batches, so it is defined once outside the loop.
input_template = "To English: <MASK>."

probs1, probs2 = [], []   # candidate scores for raw input text / generated text
labels = []
dataset = BatchedDF(df_results)
loader = DataLoader(dataset, batch_size=64, shuffle=False, drop_last=False)
# Inference only: without no_grad the appended outputs could keep their autograd
# graphs alive for the whole loop.
with torch.no_grad():
    for batch in track(loader):
        input_texts = batch['raw input text']
        gen_texts = batch['gen text']

        # Relations outside the 9 classes get -1 (skipped by `ignore_index=-1` later).
        labels.extend(relation_to_idx.get(relation, -1) for relation in batch['relation label'])

        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            prob1 = model.predict_text_embedding(input_texts, input_template, candidates)
            prob2 = model.predict_text_embedding(gen_texts, input_template, candidates)
        probs1.append(prob1)
        probs2.append(prob2)
probs1 = torch.cat(probs1, dim=0)
probs2 = torch.cat(probs2, dim=0)
labels = torch.tensor(labels, device=probs1.device)


Output()

In [46]:
# Top-3 accuracy of the text-embedding classifier, once for the raw input
# sentences (probs1) and once for the generated sentences (probs2).
# Note: despite the names, `acc1` / `acc2` here are both top-3 values.
top3_metric_kwargs = dict(num_classes=9, top_k=3, ignore_index=-1, average='micro')

acc1 = multiclass_accuracy(probs1, labels, **top3_metric_kwargs)
print('clip-like acc [raw input text]: ',acc1.item())

acc2 = multiclass_accuracy(probs2, labels, **top3_metric_kwargs)
print('clip-like acc [gen text]:       ',acc2.item())

clip-like acc [raw input text]:  0.4969387650489807
clip-like acc [gen text]:        0.4326530694961548


### Relation classification with an LLM

In [47]:
import torch
import transformers
# Build a Hugging Face pipeline around the evaluator LLM. No task is passed
# explicitly. The model is loaded in float16 on the second GPU ("cuda:1").
eval_llm_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
pipe = transformers.pipeline(
    model=eval_llm_id,
    model_kwargs=dict(torch_dtype=torch.float16),
    device_map=torch.device("cuda:1"),
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [62]:

# Sentences to classify. Swap in `dataset.raw_input_text` to score the original
# sentences instead of the generated ones.
# all_sentences = dataset.raw_input_text
all_sentences = dataset.gen_text
relation_labels = dataset.rela_label

# Keep only samples whose label is one of the known relation classes.
labelled_pairs = [(sentence, label)
                  for sentence, label in zip(all_sentences, relation_labels)
                  if label in all_relations]
input_sentences = [sentence for sentence, _ in labelled_pairs]
filtered_labels = [label for _, label in labelled_pairs]

# System prompt. The label list and count are derived from `all_relations` so
# the prompt cannot drift from the classes used for scoring.
label_list = ", ".join(all_relations[:-1]) + ", and " + all_relations[-1]
instructions = {"role": "system",
                "content":
                  (f"Your task is relation extraction. Please choose the top-3 from the below {len(all_relations)} possible labels:\n"
                   f" {label_list}.\n"
                   " Please just output the three most likely labels in descending order of probability,"
                   " and separate them by commas. Do not output any other words!"
                  )}

# One chat (system + user turn) per sentence, rendered to a prompt string.
messages = [[instructions, {"role": "user", "content": sen}] for sen in input_sentences]
inputs = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
terminators = [pipe.tokenizer.eos_token_id,
               pipe.tokenizer.convert_tokens_to_ids("<|eot_id|>")]

# Batched generation needs a pad token; the EOS token is reused and padding is
# applied on the left.
pipe.tokenizer.pad_token_id = pipe.tokenizer.eos_token_id
pipe.tokenizer.padding_side = 'left'

# Generation samples (`do_sample=True`), so seed the RNGs first; otherwise the
# reported accuracies change from run to run.
LLM_EVAL_SEED = 0
transformers.set_seed(LLM_EVAL_SEED)
with torch.no_grad():
    outputs = pipe(inputs,
                   batch_size = 16,
                   max_new_tokens = 16,
                   eos_token_id = terminators,
                   do_sample = True,
                   num_beams = 2,
                   pad_token_id = pipe.tokenizer.eos_token_id,
                   )

In [63]:
def _parse_ranked_labels(completion):
    """Split a comma-separated completion into a list of cleaned labels.

    Each piece has surrounding whitespace and a trailing period removed;
    empty pieces are dropped. Order is preserved.
    """
    cleaned = [piece.strip().rstrip('.').strip() for piece in completion.split(',')]
    return [piece for piece in cleaned if piece]


n_correct_top1 = 0
n_correct_top3 = 0
total = len(input_sentences)
for label, prompt, output in zip(filtered_labels, inputs, outputs):
    # `generated_text` starts with the prompt; slice it off to keep the completion.
    pred_top3 = output[0]['generated_text'][len(prompt):] # str, "label1,label2,label3"
    ranked = _parse_ranked_labels(pred_top3)

    if len(ranked) == 3:
        t1, t2, t3 = ranked
        print(f"Label: {label:<15}  (Pred) t1: {t1:<22} t2: {t2:<27} t3: {t3:<32}")
    else:
        # Completion does not contain exactly three labels; show it verbatim.
        print("👿"*3,pred_top3)

    # Score top-1 and top-3 from the same parsed list. Exact matches on the
    # first three labels only, so extra labels or stray text earn no credit.
    top3 = ranked[:3]
    if top3 and label == top3[0]:
        n_correct_top1 += 1
    if label in top3:
        n_correct_top3 += 1

# Guard against an empty evaluation set.
llm_acc1 = n_correct_top1/total if total else float('nan')
llm_acc3 = n_correct_top3/total if total else float('nan')
print('llm-pred acc-top1: ', llm_acc1)
print('llm-pred acc-top3: ', llm_acc3)

Label: awarding         (Pred) t1: employment             t2:  awarding                   t3:  education                      
Label: awarding         (Pred) t1: employment             t2:  visit                      t3:  awarding                       
Label: awarding         (Pred) t1: employment             t2:  education                  t3:  job title                      
Label: awarding         (Pred) t1: employment             t2:  political affiliation      t3:  job title                      
Label: awarding         (Pred) t1: employment             t2:  awarding                   t3:  foundation                     
Label: awarding         (Pred) t1: employment             t2:  education                  t3:  awarding                       
Label: awarding         (Pred) t1: employment             t2:  education                  t3:  awarding                       
Label: education        (Pred) t1: employment             t2:  nationality                t3:  job title       