In [27]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import pandas as pd

In [28]:
# Checkpoint to load from the Hugging Face Hub.
model_id = "meta-llama/Llama-3.2-3B"

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# float16 halves the memory footprint on GPU; on CPU stay in float32, where
# half-precision kernels are not reliably available.
model_dtype = torch.float16 if device.type == "cuda" else torch.float32

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Place the weights on the selected device while loading. Driving device_map
# from `device` (instead of a hard-coded "cuda") keeps the CPU fallback usable
# and makes a separate `model.to(device)` call unnecessary.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=model_dtype,
    device_map=device,
)

# Show the architecture as the cell output.
model

Downloading shards: 100%|██████████| 2/2 [02:32<00:00, 76.42s/it] 
Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.47s/it]


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 3072)
    (layers): ModuleList(
      (0-27): 28 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (k_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (v_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (up_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((3072,), eps=1e-05)
    (rotary_emb

In [6]:
# Parquet files of the jhu-clsp/jfleg dataset, keyed by split name.
splits = dict(
    validation='data/validation-00000-of-00001.parquet',
    test='data/test-00000-of-00001.parquet',
)

# Only the validation split is read here, straight from its hf:// URL.
validation_url = f"hf://datasets/jhu-clsp/jfleg/{splits['validation']}"
df = pd.read_parquet(validation_url)

df

Unnamed: 0,sentence,corrections
0,So I think we can not live if old people could...,[So I think we would not be alive if our ances...
1,For not use car .,"[Not for use with a car . , Do not use in the ..."
2,Here was no promise of morning except that we ...,"[Here was no promise of morning , except that ..."
3,Thus even today sex is considered as the least...,"[Thus , even today , sex is considered as the ..."
4,image you salf you are wark in factory just to...,[Imagine yourself you are working in factory j...
...,...,...
750,The government also should try to reduce the s...,[The government should also try to reduce the ...
751,Alot of memories with enogh time to remember w...,"[A lot of memories , with enough time to remem..."
752,Sceene of violence can affect on them .,[A scene of violence can have an effect on the...
753,While the communities in general have reckoned...,[The communities in general have reckoned that...


In [7]:
# Work with a small sample: the first ten source sentences as a plain list.
sentences = df["sentence"].iloc[:10].tolist()
sentences

['So I think we can not live if old people could not find siences and tecnologies and they did not developped . ',
 'For not use car . ',
 'Here was no promise of morning except that we looked up through the trees we saw how low the forest had swung . ',
 'Thus even today sex is considered as the least important topic in many parts of India . ',
 'image you salf you are wark in factory just to do one thing like pot taire on car if they fire you you will destroy , becouse u dont know more than pot taire in car . ',
 'They draw the consumers , like me , to purchase this great product with all these amazing ingredients and all that but actually they just sometimes make something up just to increase their sales . ',
 'I want to talk about nocive or bad products like alcohol , hair spray and cigarrets . ',
 'For example they can play football whenever they want but the olders can not . ',
 'It figures Diana Krall wearing a Rolex watch and has a text that suggests that if the reader wants to

In [8]:
# Reuse the end-of-sequence token as the padding token so that sentences of
# different lengths can be batched (presumably the tokenizer has no pad token
# of its own -- verify against the checkpoint's tokenizer config).
tokenizer.pad_token = tokenizer.eos_token

# Encode the whole batch at once: pad to the longest sentence and cut anything
# beyond 128 tokens, returning PyTorch tensors.
encoded_batch = tokenizer(
    sentences,
    padding=True,
    truncation=True,
    max_length=128,
    return_tensors="pt",
)

# Plain dict of tensors moved to the device chosen earlier in the notebook.
tokenized_sentences = {
    field: values.to(device) for field, values in encoded_batch.items()
}

In [20]:
# Single forward pass over the tokenized batch, without gradient tracking.
# No generation options are passed: this is not a `generate()` call, and the
# sequence length is already fixed by the tokenized inputs.
with torch.no_grad():
    output = model(
        **tokenized_sentences,
        output_hidden_states=True,  # keep the per-layer hidden states
        return_dict=True,
    )

In [26]:
# Hidden states of the last layer for every position in the batch.
last_hidden_states = output.hidden_states[-1]
print(last_hidden_states.shape)

# Select the hidden state of each sentence's last *attended* token.
# The batch was tokenized with padding=True, so plain `[:, -1, :]` indexing can
# land on a padding position for every sentence shorter than the longest one.
# Taking the highest position whose attention mask is 1 works for both left-
# and right-padded batches.
attention_mask = tokenized_sentences["attention_mask"]
positions = torch.arange(attention_mask.shape[1], device=attention_mask.device)
last_token_index = (attention_mask * positions).argmax(dim=1)
last_token_index = last_token_index.to(last_hidden_states.device)

batch_index = torch.arange(
    last_hidden_states.shape[0], device=last_hidden_states.device
)
last_token_hidden_states = last_hidden_states[batch_index, last_token_index]
print(last_token_hidden_states.shape)

torch.Size([10, 56, 2048])
torch.Size([10, 2048])


In [22]:
# Dump the raw values of the last entry in the hidden-state tuple.
final_layer_states = output.hidden_states[-1]
print(final_layer_states)

tensor([[[ 1.5498, -1.5039,  2.7441,  ..., -0.6343,  0.6006, -1.0107],
         [ 1.7227,  4.1797,  0.4827,  ..., -3.5352, -4.4727, -0.0364],
         [ 2.9551,  3.5996,  0.8955,  ...,  0.5044, -5.7305, -0.7026],
         ...,
         [-0.5122,  3.4785, -0.7607,  ..., -0.3057,  2.7070,  0.2009],
         [-0.4985,  3.4707, -0.7036,  ..., -0.2144,  2.6152,  0.3516],
         [-0.5576,  3.4648, -0.6382,  ..., -0.0870,  2.5430,  0.4092]],

        [[ 1.5498, -1.5039,  2.7441,  ..., -0.6343,  0.6006, -1.0107],
         [ 0.9595,  2.0977,  2.1934,  ..., -5.8633, -5.9609,  0.2430],
         [ 0.8931,  3.8398,  1.7012,  ..., -3.0137, -5.3867, -1.1514],
         ...,
         [-0.5483,  3.4922,  1.5713,  ...,  0.7861,  0.5498,  0.2510],
         [-0.4812,  3.4629,  1.6123,  ...,  0.8403,  0.5264,  0.3218],
         [-0.4551,  3.5176,  1.6816,  ...,  0.9189,  0.4802,  0.2788]],

        [[ 1.5498, -1.5039,  2.7441,  ..., -0.6343,  0.6006, -1.0107],
         [ 0.8232,  3.6406,  2.4062,  ..., -5