### Zero-shot classification with `eeg_emb`

In [None]:
import torch
from rich.progress import track

from model.glim import GLIM
from data.datamodule import GLIMDataModule
from torchmetrics.functional.classification import multiclass_accuracy

# All GLIM inference in this notebook runs on the first GPU; the checkpoint
# weights are mapped directly onto it while loading.
device = torch.device("cuda:0")

# Restore the trained GLIM model. `strict=False` relaxes the key matching
# between the checkpoint and the module's state dict.
# Optional overrides that were tried here (currently disabled):
#   evaluate_prompt_embed = 'src'
#   prompt_dropout_probs = (0.0, 0.1, 0.1)
model = GLIM.load_from_checkpoint(
    "checkpoints/glim-zuco-epoch=199-step=49600.ckpt",
    map_location=device,
    strict=False,
)
model.setup(stage='test')

# Data module for the test split: no noise input, 24 samples per test batch.
dm = GLIMDataModule(
    data_path='./data/tmp/zuco_eeg_label_8variants.df',
    eval_noise_input=False,
    bsz_test=24,
)
dm.setup(stage='test')

[Rank 0][GLIMDataModule] running `setup()`...
[Rank 0][GLIMDataModule] running `setup()`...Done! 😋😋😋


In [7]:
# One candidate sentence per sentiment class; the model scores every EEG sample
# against these candidates.
prefix = "Sentiment classification: "
template = "It is <MASK>."
all_sentiments = ['negative', 'neutral', 'positive']
# Alternative candidate wording (disabled):
# candidates = ['bad or boring',
#                 'normal or ordinary',
#                 'good or great',]
candidates = [prefix + template.replace("<MASK>", label) for label in all_sentiments]

results = []
with torch.no_grad():
    for batch in track(dm.test_dataloader()):
        eeg = batch['eeg'].to(device)
        eeg_mask = batch['mask'].to(device)
        # NOTE: [tuple('task'), tuple('dataset'), tuple('subject')] after collate
        prompts = batch['prompt']
        raw_task_key = batch['raw task key']        # list[str]
        sentiment_label = batch['sentiment label']  # list[str]

        # Encode each label as its index in `all_sentiments`; anything else
        # becomes -1 so the metric below can skip it via `ignore_index`.
        labels = torch.tensor(
            [all_sentiments.index(name) if name in all_sentiments else -1
             for name in sentiment_label],
            device=device,
        )

        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            prob, gen_str = model.predict(eeg, eeg_mask, prompts, candidates, generate=False)

        # Flatten the batch into one record per sample.
        for i in range(len(eeg)):
            record = {
                'input text': batch['raw input text'][i],
                'label': labels[i],
                'prob': prob[i],
                'gen_str': gen_str[i],
            }
            results.append(record)

# Re-assemble per-sample records into tensors and compute top-1 accuracy,
# ignoring the samples marked with -1.
probs = torch.stack([row['prob'] for row in results])
labels = torch.stack([row['label'] for row in results])
acc1 = multiclass_accuracy(
    probs,
    labels,
    num_classes=3,
    top_k=1,
    ignore_index=-1,
    average='micro',
)
print('clip-like acc1: ',acc1.item())



Output()

clip-like acc1:  0.4292565882205963


## Load generated texts, labels and raw texts

In [None]:
import pandas as pd
# Load the saved generation results. Later cells read the columns
# 'raw input text', 'gen text', 'sentiment label' and 'relation label'.
# NOTE(review): unpickling can execute arbitrary code - only load files you trust.
df_results = pd.read_pickle('data/tmp/glim_gen_results.pkl')

### Zero-shot classification on text embeddings

In [3]:
from torch.utils.data import DataLoader, Dataset


class BatchedDF(Dataset):
    """Map-style dataset that serves the rows of a results DataFrame as dicts.

    The four columns 'raw input text', 'gen text', 'sentiment label' and
    'relation label' are copied into plain Python lists at construction time.
    Both label columns are converted to `str` element by element, so a missing
    value shows up as the string 'nan'.
    """

    def __init__(self, df: pd.DataFrame) -> None:
        """Snapshot the required columns of `df` as lists."""
        self.raw_input_text = df['raw input text'].tolist()
        self.gen_text = df['gen text'].tolist()
        # Stringify labels so every item field has a uniform, collatable type.
        self.senti_label = df['sentiment label'].map(str).tolist()
        self.rela_label = df['relation label'].map(str).tolist()

    def __len__(self):
        """Number of rows held by the dataset."""
        return len(self.raw_input_text)

    def __getitem__(self, idx):
        """Return the row at `idx` as a dict keyed by the original column names."""
        columns = (
            ('raw input text', self.raw_input_text),
            ('gen text', self.gen_text),
            ('sentiment label', self.senti_label),
            ('relation label', self.rela_label),
        )
        return {key: values[idx] for key, values in columns}
    

# Candidate sentences (one per sentiment class) that every text is scored against.
prefix = "Sentiment: "
template = "It is <MASK>."
all_sentiments = ['negative', 'neutral', 'positive']
candidates = [prefix + template.replace("<MASK>", label) for label in all_sentiments]
# Template passed to `predict_text_embedding` together with each batch of
# texts. It does not change between batches, so it is defined once here.
input_template = "Sentiment classification: <MASK>."

# Keep only samples that carry a sentiment label. Comparing through
# `astype(str)` drops both the literal string 'nan' and real float NaN values;
# a plain `!= 'nan'` lets float NaN through because NaN != 'nan' is True.
df_filtered = df_results[df_results['sentiment label'].astype(str) != 'nan']
dataset = BatchedDF(df_filtered)
loader = DataLoader(dataset, batch_size=64, shuffle=False, drop_last=False)
probs1, probs2 = [], []
labels = []
# Pure inference: without `no_grad` each stored probability tensor may keep its
# whole autograd graph alive, so memory grows with every batch.
with torch.no_grad():
    for batch in track(loader):
        input_texts = batch['raw input text']
        gen_texts = batch['gen text']
        senti_label = batch['sentiment label']
        # rela_label = batch['relation label']

        # Class index per sample; unknown labels become -1 so the metric can
        # skip them through `ignore_index=-1`.
        labels.extend(
            all_sentiments.index(name) if name in all_sentiments else -1
            for name in senti_label
        )

        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            # prob1: scores for the ground-truth text, prob2: for the generated text.
            prob1 = model.predict_text_embedding(input_texts, input_template, candidates)
            prob2 = model.predict_text_embedding(gen_texts, input_template, candidates)
        probs1.append(prob1)
        probs2.append(prob2)
probs1 = torch.cat(probs1, dim=0)
probs2 = torch.cat(probs2, dim=0)
labels = torch.tensor(labels, device=probs1.device)


In [35]:
# Top-1 micro accuracy over the three sentiment classes; samples labelled -1
# are excluded from the score.
metric_kwargs = dict(num_classes=3, top_k=1, ignore_index=-1, average='micro')

# Accuracy when classifying the ground-truth input text.
acc1 = multiclass_accuracy(probs1, labels, **metric_kwargs)
print('clip-like acc [raw input text]: ',acc1.item())

# Accuracy when classifying the text generated by the model.
acc2 = multiclass_accuracy(probs2, labels, **metric_kwargs)
print('clip-like acc [gen text]:       ',acc2.item())

clip-like acc [raw input text]:  0.4556354880332947
clip-like acc [gen text]:        0.3956834673881531


### Zero-shot classification with an LLM

In [2]:
import torch
import transformers
# Instruction-tuned LLM used as an external sentiment classifier.
eval_llm_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Load the weights in float16 and place the whole model on cuda:1
# (the GLIM model earlier in this notebook was loaded on cuda:0).
pipe = transformers.pipeline(
    model=eval_llm_id,
    model_kwargs={"torch_dtype": torch.float16},
    device_map=torch.device("cuda:1"),
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [6]:
# Sentences to classify. Switch to `dataset.raw_input_text` to score the
# ground-truth texts instead of the generated ones.
# input_sentences = dataset.raw_input_text
input_sentences = dataset.gen_text
filtered_labels = dataset.senti_label

# System prompt restricting the LLM to a bare label.
# NOTE(review): the wording (including "You task") is kept verbatim so results
# stay comparable with earlier runs - confirm before editing the prompt.
instructions = {"role": "system",
          "content":
            ("You task is sentiment classification. Please pick the most likely label from:\n"
             " negative, neutral and positive.\n"
             " Please just output your predicted label do not output any other words!"
            )}

# One chat per sentence, rendered to a prompt string that ends with the
# assistant generation header.
messages = [[instructions, {"role": "user", "content": sen}] for sen in input_sentences]
inputs = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# Stop on either the regular EOS token or the end-of-turn token.
terminators = [pipe.tokenizer.eos_token_id,
               pipe.tokenizer.convert_tokens_to_ids("<|eot_id|>")]

# Batched generation needs a pad token; pad on the left so every prompt ends
# right where generation starts.
pipe.tokenizer.pad_token_id = pipe.tokenizer.eos_token_id
pipe.tokenizer.padding_side = 'left'

with torch.no_grad():
    outputs = pipe(
        inputs,
        batch_size=16,
        max_new_tokens=8,
        eos_token_id=terminators,
        # Deterministic beam search: sampling would make the reported accuracy
        # differ from run to run.
        do_sample=False,
        num_beams=2,
        # Unset the sampling-only parameters so a checkpoint generation config
        # that defines them does not warn when `do_sample=False`.
        temperature=None,
        top_p=None,
        pad_token_id=pipe.tokenizer.eos_token_id,
    )

In [7]:
# Score the LLM predictions against the string sentiment labels.
n_correct = 0
invalid_preds = []  # (sample index, raw completion) for outputs that are not a known label
total = len(input_sentences)
for i in range(total):
    label = filtered_labels[i]
    # `generated_text` starts with the prompt; everything after it is the completion.
    raw_pred = outputs[i][0]['generated_text'][len(inputs[i]):]
    # Normalise harmless formatting differences ("Positive", " positive.", ...)
    # instead of aborting the whole evaluation on them.
    pred = raw_pred.strip().lower().rstrip('.')
    if pred not in all_sentiments:
        # Unparseable answers count as wrong and are reported below.
        invalid_preds.append((i, raw_pred))
        continue

    if label == pred:
        n_correct += 1

    # print(f'label: {label}  pred: {pred}  gen_str: {gen_str}')
if invalid_preds:
    print(f'warning: {len(invalid_preds)} prediction(s) were not a valid label, '
          f'e.g. {invalid_preds[0]!r}')
# Guard against an empty evaluation set.
llm_acc = n_correct/total if total else float('nan')
print('llm-pred acc: ', llm_acc)

llm-pred acc:  0.39568345323741005
