In [141]:
from datasets import load_dataset
from torch.utils.data import Dataset,DataLoader,DistributedSampler
from transformers import AutoTokenizer, LlamaForCausalLM
import torch

In [142]:
class Dataset_alpaca(Dataset):
    def __init__(self,texts,prefix_prompt = "", suffix_prompt=""):
        """self.list_IDs	: list of strings (each string: utt key),
           self.labels      : dictionary (key: utt key, value: label integer)"""
        self.texts = texts
        self.prefix_prompt = prefix_prompt
        self.suffix_prompt = suffix_prompt

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index):
        return self.prefix_prompt+self.texts[index]['instruction'] + self.texts[index]['input']+self.suffix_prompt

In [143]:
ds = load_dataset("/data8/wangzhiyong/project/LLM/llama_omni/a_datasets/alpaca/subset/data")
suffix_prompt = " Respond to the following input precisely and shortly in english. "
txtset = Dataset_alpaca(ds["train"], suffix_prompt=suffix_prompt)
alpaca_dl = DataLoader(txtset, batch_size=8, shuffle=True,drop_last = True,num_workers=0)
print(len(txtset))

Generating train split: 1000 examples [00:00, 108945.79 examples/s]

1000





In [144]:
for ind, a in enumerate(alpaca_dl):
    if ind==0:
        q = a
        break
q

['Describe how machine learning can be used to automate mundane tasks. Respond to the following input precisely and shortly in english. ',
 'Name 6 components of an artificial neural network Respond to the following input precisely and shortly in english. ',
 'Create a dialog between two people who are discussing a scientific phenomenonHydropower Respond to the following input precisely and shortly in english. ',
 'Explain the given concept in one sentence.Algorithmic complexity Respond to the following input precisely and shortly in english. ',
 'Answer this math problem.12/8 Respond to the following input precisely and shortly in english. ',
 'Name 3 countries that border France. Respond to the following input precisely and shortly in english. ',
 'Create a list of three tips for public speaking. Respond to the following input precisely and shortly in english. ',
 'Describe how quantum computers work. Respond to the following input precisely and shortly in english. ']

In [145]:
llama_dir="/data8/wangzhiyong/project/LLM/llama_omni/b_pretrained_models/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(llama_dir)
model = LlamaForCausalLM.from_pretrained(llama_dir)
tokenizer.pad_token_id = tokenizer.eos_token_id

In [146]:
txttoken = tokenizer(q, return_tensors="pt",padding=True)

In [147]:
txttoken.input_ids

tensor([[128000,  75885,   1268,   5780,   6975,    649,    387,   1511,    311,
          69711,  69782,   9256,     13,  40633,    311,    279,   2768,   1988,
          24559,    323,  20193,    304,  30063,     13,    220, 128009, 128009,
         128009],
        [128000,    678,    220,     21,   6956,    315,    459,  21075,  30828,
           4009,  40633,    311,    279,   2768,   1988,  24559,    323,  20193,
            304,  30063,     13,    220, 128009, 128009, 128009, 128009, 128009,
         128009],
        [128000,   4110,    264,   7402,   1990,   1403,   1274,    889,    527,
          25394,    264,  12624,  25885,  31916,   6861,   1223,  40633,    311,
            279,   2768,   1988,  24559,    323,  20193,    304,  30063,     13,
            220],
        [128000,    849,  21435,    279,   2728,   7434,    304,    832,  11914,
           9833,   7240,    292,  23965,  40633,    311,    279,   2768,   1988,
          24559,    323,  20193,    304,  30063,     13

In [148]:
generated_token = model.generate(**txttoken, max_length=500)

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


In [149]:
generated_token

tensor([[128000,  75885,   1268,  ..., 128006, 128006, 128006],
        [128000,    678,    220,  ..., 128006, 128006, 128006],
        [128000,   4110,    264,  ...,    220,   1226,   1205],
        ...,
        [128000,    678,    220,  ..., 128006, 128006, 128006],
        [128000,   4110,    264,  ..., 128006, 128006, 128006],
        [128000,  75885,   1268,  ..., 128001, 128001, 128001]])

In [150]:
optxt = tokenizer.batch_decode([generated_token[i][len(txttoken.input_ids[i]):] for i in range(len(generated_token))], skip_special_tokens=True, clean_up_tokenization_spaces=False,)


In [151]:
optxt

['',
 '',
 ' I am a scientist, and I have been studying the effects of climate change on the world\'s oceans.  I am particularly interested in the phenomenon of ocean acidification, which is caused by the absorption of CO2 from the atmosphere.  I have been studying the impact of this phenomenon on marine life, and I am concerned about the potential consequences for the future of our planet.\n\nA colleague:  "I\'m not sure I agree with your perspective on this.  Ocean acidification is not just a problem for marine life, it\'s also a problem for human health.  We\'re already seeing the effects of rising CO2 levels on coral reefs and other marine ecosystems."\n\nScientist:  "I understand your point, but I still think that the impact of ocean acidification on marine life is more significant than the impact on human health.  For example, many marine species are already showing signs of adaptation to the changing CO2 levels, and some species are even evolving to thrive in these conditions.  