In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset, DownloadMode
import numpy as np
from collections import defaultdict
import os
from tqdm import tqdm
import pickle

In [20]:
class TopicDataset(Dataset):
    """Map-style dataset that tokenizes the 'positive' texts of a pickled topic.

    The pickle file is read once in ``__init__``. The loaded object is indexed
    as ``data[first_key]['positive']``, where ``first_key`` is the first key
    returned by ``data.keys()``; the result is kept as ``self.topic_texts``.

    Args:
        path_to_topic_pkl: Path of the pickle file to read.
        is_positive: Accepted for interface compatibility but not read by this
            class; the 'positive' entry is always used.
            NOTE(review): presumably meant to select between positive and
            other samples -- confirm the pickle schema before wiring it in.
        tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
            truncation=True, padding='max_length', return_tensors='pt')`` whose
            result is indexed by 'input_ids' and 'attention_mask'.
        max_length: Forwarded to the tokenizer as ``max_length``.
    """

    def __init__(self, path_to_topic_pkl : str, is_positive, tokenizer, max_length : int):
        # SECURITY: pickle.load can execute arbitrary code while deserializing;
        # only point this at files from a trusted source.
        with open(path_to_topic_pkl, 'rb') as f:
            data = pickle.load(f)
        # Only the first top-level key of the pickle is used.
        self.topic_texts = data[list(data.keys())[0]]['positive']
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of texts available for the topic."""
        return len(self.topic_texts)

    def __getitem__(self, index):
        """Tokenize the text at ``index``.

        Returns:
            A dict with 'input_ids' and 'attention_mask' (the tokenizer outputs
            with their leading dimension squeezed out) and 'idx' (the index
            that was requested).
        """
        text = self.topic_texts[index]
        encoding = self.tokenizer(
            text,
            max_length = self.max_length,
            truncation = True,
            padding = 'max_length',
            return_tensors = 'pt'
        )

        # squeeze(0) drops the leading dimension that the tokenizer output
        # carries for a single text.
        return {
            'input_ids':encoding['input_ids'].squeeze(0),
            'attention_mask':encoding['attention_mask'].squeeze(0),
            'idx':index
        }

In [None]:
# Checkpoint to load and the directory passed as cache_dir to both loaders.
MODEL_NAME = "openlm-research/open_llama_3b"
CACHE_DIR = '/home/jovyan/rusakov/dim_lm/cache'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
# Fall back to the EOS token when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load in bfloat16 with device_map="auto", then switch to eval mode;
# eval() returns the module itself, so the assignment keeps the same object.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    cache_dir=CACHE_DIR,
).eval()

In [4]:
# Quick look at the raw contents of one topic pickle.
# `pickle` is already imported in the notebook's import cell.
topic = 'City district.pkl'
path = f'/home/jovyan/zorin/GCS/datasets/openai/{topic}'
print(path)

# SECURITY: pickle.load can execute arbitrary code while deserializing;
# only open files from a trusted source.
with open(path, 'rb') as f:
    print(pickle.load(f))

/home/jovyan/zorin/GCS/datasets/openai/City district.pkl
