# Installations

In [41]:
!pip install transformers accelerate datasets sentencepiece scikit-learn
!pip install torch torchvision torchaudio

Collecting scikit-learn
  Using cached scikit_learn-1.5.2-cp311-cp311-macosx_12_0_arm64.whl.metadata (13 kB)
Collecting scipy>=1.6.0 (from scikit-learn)
  Using cached scipy-1.14.1-cp311-cp311-macosx_14_0_arm64.whl.metadata (60 kB)
Collecting threadpoolctl>=3.1.0 (from scikit-learn)
  Using cached threadpoolctl-3.5.0-py3-none-any.whl.metadata (13 kB)
Using cached scikit_learn-1.5.2-cp311-cp311-macosx_12_0_arm64.whl (11.0 MB)
Using cached scipy-1.14.1-cp311-cp311-macosx_14_0_arm64.whl (23.1 MB)
Using cached threadpoolctl-3.5.0-py3-none-any.whl (18 kB)
Installing collected packages: threadpoolctl, scipy, scikit-learn
Successfully installed scikit-learn-1.5.2 scipy-1.14.1 threadpoolctl-3.5.0


In [1]:
import gc
import torch

import pandas as pd

from typing import List, Dict, Union

from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from transformers import T5TokenizerFast
from transformers import T5ForConditionalGeneration
from sklearn.model_selection import train_test_split
from transformers import Trainer, TrainingArguments

  from .autonotebook import tqdm as notebook_tqdm


# Load dataset

In [1]:

df = pd.read_csv("hf://datasets/s-nlp/paradetox/train.tsv", sep="\t")
df

  from .autonotebook import tqdm as notebook_tqdm


Unnamed: 0,en_toxic_comment,en_neutral_comment
0,he had steel balls too !,he was brave too!
1,"dude should have been taken to api , he would ...",It would have been good if he went to api. He ...
2,"im not gonna sell the fucking picture , i just...","I'm not gonna sell the picture, i just want to..."
3,the garbage that is being created by cnn and o...,the news that is being created by cnn and othe...
4,the reason they dont exist is because neither ...,The reason they don't exist is because neither...
...,...,...
19739,when they do shit like this .,when they do stuff like this
19740,"but if saying "" fuck that group "" is much more...","but if saying"" that group is bad"" is much more..."
19741,"it hurts how judgemental assholes view them , ...",It hurts how judgemental that people view them...
19742,shit we probably literally blow that up in a w...,We probably litteralky blow that up in a week.


In [3]:
xydf = pd.DataFrame({'source': df["en_toxic_comment"].values.tolist(), 'target': df["en_neutral_comment"].values.tolist()})

In [4]:
xydf

Unnamed: 0,source,target
0,he had steel balls too !,he was brave too!
1,"dude should have been taken to api , he would ...",It would have been good if he went to api. He ...
2,"im not gonna sell the fucking picture , i just...","I'm not gonna sell the picture, i just want to..."
3,the garbage that is being created by cnn and o...,the news that is being created by cnn and othe...
4,the reason they dont exist is because neither ...,The reason they don't exist is because neither...
...,...,...
19739,when they do shit like this .,when they do stuff like this
19740,"but if saying "" fuck that group "" is much more...","but if saying"" that group is bad"" is much more..."
19741,"it hurts how judgemental assholes view them , ...",It hurts how judgemental that people view them...
19742,shit we probably literally blow that up in a w...,We probably litteralky blow that up in a week.


# Model checkpoints

In [5]:
checkpoint = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [6]:
model_name = "t5-small"
tokenizer = T5TokenizerFast.from_pretrained(model_name)
df_train, df_test = train_test_split(xydf, test_size=300)
print(df_train.shape[0], df_test.shape[0])


x1 = tokenizer(df_train.source.tolist(), truncation=True)
y1 = tokenizer(df_train.target.tolist(), truncation=True)
x2 = tokenizer(df_test.source.tolist(), truncation=True)
y2 = tokenizer(df_test.target.tolist(), truncation=True)

19444 300


In [7]:
class PairsDataset(torch.utils.data.Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, idx):
        assert idx < len(self.x['input_ids'])
        item = {key: val[idx] for key, val in self.x.items()}
        item['decoder_attention_mask'] = self.y['attention_mask'][idx]
        item['labels'] = self.y['input_ids'][idx]
        return item
    
    @property
    def n(self):
        return len(self.x['input_ids'])

    def __len__(self):
        return self.n # * 2
    
train_dataset = PairsDataset(x1, y1)
test_dataset = PairsDataset(x2, y2)
len(train_dataset), len(test_dataset)

(19444, 300)

In [9]:
train_dataloader = DataLoader(train_dataset, batch_size=4, drop_last=True, shuffle=True, num_workers=1)
test_dataloader = DataLoader(test_dataset, batch_size=4, drop_last=True, shuffle=True, num_workers=1)

In [11]:

class DataCollatorWithPadding:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        batch = self.tokenizer.pad(
            features,
            padding=True,
        )
        ybatch = self.tokenizer.pad(
            {'input_ids': batch['labels'], 'attention_mask': batch['decoder_attention_mask']},
            padding=True,
        ) 
        batch['labels'] = ybatch['input_ids']
        batch['decoder_attention_mask'] = ybatch['attention_mask']
        
        return {k: torch.tensor(v) for k, v in batch.items()}

In [12]:
model = T5ForConditionalGeneration.from_pretrained(model_name)

In [14]:


if torch.backends.mps.is_available():
    device = torch.device('mps')  # Use MPS (Metal Performance Shaders) on Apple Silicon
else:
    device = torch.device('cpu')  # Fallback to CPU if MPS is not available

model.to(device)


print(device)


mps


In [16]:
save_name = 'models/t5-baseline'

In [20]:

training_args = TrainingArguments(
    output_dir=save_name,             # output directory
    overwrite_output_dir=True,
    num_train_epochs=3,               # total number of training epochs
    per_device_train_batch_size=4,    # batch size per device during training
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=8,     # batch size for evaluation
    warmup_steps=300,                 # number of warmup steps for learning rate scheduler
    weight_decay=0,                   # strength of weight decay
    learning_rate=3e-5,
    logging_dir='./logs',             # directory for storing logs
    logging_steps=100,
    eval_steps=100,
    evaluation_strategy='steps',      # evaluation strategy (e.g., 'steps', 'epoch')
    save_total_limit=1,               # limit the number of saved checkpoints
    save_steps=5000,                  # save checkpoints every 5000 steps
)


In [21]:
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [22]:
trainer = Trainer(
    model=model,                         
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=test_dataset,           # evaluation dataset
    data_collator=data_collator,
    tokenizer=tokenizer,
)

In [23]:
gc.collect()
torch.cuda.empty_cache() 

In [24]:
trainer.train()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
  0%|          | 0/3645 [00:00<?, ?it/s]You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  3%|▎         | 100/3645 [00:57<24:23,  2.42it/s]

{'loss': 4.9263, 'grad_norm': 36.59382629394531, 'learning_rate': 9.999999999999999e-06, 'epoch': 0.08}


                                                  
  3%|▎         | 100/3645 [01:01<24:23,  2.42it/s]

{'eval_loss': 3.6027963161468506, 'eval_runtime': 4.3918, 'eval_samples_per_second': 68.309, 'eval_steps_per_second': 8.653, 'epoch': 0.08}


  5%|▌         | 200/3645 [01:46<23:19,  2.46it/s]  

{'loss': 2.7939, 'grad_norm': 16.23082733154297, 'learning_rate': 1.9999999999999998e-05, 'epoch': 0.16}


                                                  
  5%|▌         | 200/3645 [01:48<23:19,  2.46it/s]

{'eval_loss': 1.1944864988327026, 'eval_runtime': 1.5296, 'eval_samples_per_second': 196.134, 'eval_steps_per_second': 24.844, 'epoch': 0.16}


  8%|▊         | 300/3645 [02:32<22:55,  2.43it/s]

{'loss': 1.628, 'grad_norm': 3.375939130783081, 'learning_rate': 3e-05, 'epoch': 0.25}


                                                  
  8%|▊         | 300/3645 [02:34<22:55,  2.43it/s]

{'eval_loss': 0.8312837481498718, 'eval_runtime': 1.6175, 'eval_samples_per_second': 185.473, 'eval_steps_per_second': 23.493, 'epoch': 0.25}


 11%|█         | 400/3645 [03:19<22:52,  2.36it/s]

{'loss': 1.1628, 'grad_norm': 3.3008182048797607, 'learning_rate': 2.9103139013452915e-05, 'epoch': 0.33}


                                                  
 11%|█         | 400/3645 [03:21<22:52,  2.36it/s]

{'eval_loss': 0.7325024604797363, 'eval_runtime': 1.5506, 'eval_samples_per_second': 193.479, 'eval_steps_per_second': 24.507, 'epoch': 0.33}


 14%|█▎        | 500/3645 [04:06<24:27,  2.14it/s]

{'loss': 1.0124, 'grad_norm': 1.49667227268219, 'learning_rate': 2.820627802690583e-05, 'epoch': 0.41}


                                                  
 14%|█▎        | 500/3645 [04:08<24:27,  2.14it/s]

{'eval_loss': 0.6901258230209351, 'eval_runtime': 1.4448, 'eval_samples_per_second': 207.636, 'eval_steps_per_second': 26.301, 'epoch': 0.41}


 16%|█▋        | 600/3645 [04:52<21:41,  2.34it/s]

{'loss': 0.9491, 'grad_norm': 1.7835198640823364, 'learning_rate': 2.7309417040358744e-05, 'epoch': 0.49}


                                                  
 16%|█▋        | 600/3645 [04:54<21:41,  2.34it/s]

{'eval_loss': 0.6658235788345337, 'eval_runtime': 1.5315, 'eval_samples_per_second': 195.884, 'eval_steps_per_second': 24.812, 'epoch': 0.49}


 19%|█▉        | 700/3645 [05:36<19:49,  2.48it/s]

{'loss': 0.9784, 'grad_norm': 1.6177520751953125, 'learning_rate': 2.641255605381166e-05, 'epoch': 0.58}


                                                  
 19%|█▉        | 700/3645 [05:38<19:49,  2.48it/s]

{'eval_loss': 0.6498381495475769, 'eval_runtime': 1.581, 'eval_samples_per_second': 189.754, 'eval_steps_per_second': 24.036, 'epoch': 0.58}


 22%|██▏       | 800/3645 [06:20<20:44,  2.29it/s]

{'loss': 0.9288, 'grad_norm': 1.9504826068878174, 'learning_rate': 2.5515695067264577e-05, 'epoch': 0.66}


                                                  
 22%|██▏       | 800/3645 [06:22<20:44,  2.29it/s]

{'eval_loss': 0.6420889496803284, 'eval_runtime': 1.5274, 'eval_samples_per_second': 196.413, 'eval_steps_per_second': 24.879, 'epoch': 0.66}


 25%|██▍       | 900/3645 [07:05<19:28,  2.35it/s]

{'loss': 0.8678, 'grad_norm': 2.0610134601593018, 'learning_rate': 2.461883408071749e-05, 'epoch': 0.74}


                                                  
 25%|██▍       | 900/3645 [07:06<19:28,  2.35it/s]

{'eval_loss': 0.6335931420326233, 'eval_runtime': 1.495, 'eval_samples_per_second': 200.662, 'eval_steps_per_second': 25.417, 'epoch': 0.74}


 27%|██▋       | 1000/3645 [07:48<18:11,  2.42it/s]

{'loss': 0.8808, 'grad_norm': 1.2415040731430054, 'learning_rate': 2.3721973094170406e-05, 'epoch': 0.82}


                                                   
 27%|██▋       | 1000/3645 [07:49<18:11,  2.42it/s]

{'eval_loss': 0.6248595714569092, 'eval_runtime': 1.5039, 'eval_samples_per_second': 199.479, 'eval_steps_per_second': 25.267, 'epoch': 0.82}


 30%|███       | 1100/3645 [08:31<17:04,  2.48it/s]

{'loss': 0.8553, 'grad_norm': 1.8695703744888306, 'learning_rate': 2.282511210762332e-05, 'epoch': 0.91}


                                                   
 30%|███       | 1100/3645 [08:32<17:04,  2.48it/s]

{'eval_loss': 0.6185200214385986, 'eval_runtime': 1.6246, 'eval_samples_per_second': 184.662, 'eval_steps_per_second': 23.391, 'epoch': 0.91}


 33%|███▎      | 1200/3645 [09:15<20:36,  1.98it/s]

{'loss': 0.8864, 'grad_norm': 1.810689926147461, 'learning_rate': 2.1928251121076235e-05, 'epoch': 0.99}


                                                   
 33%|███▎      | 1200/3645 [09:16<20:36,  1.98it/s]

{'eval_loss': 0.6155784726142883, 'eval_runtime': 1.4706, 'eval_samples_per_second': 203.999, 'eval_steps_per_second': 25.84, 'epoch': 0.99}


 36%|███▌      | 1300/3645 [10:01<16:55,  2.31it/s]

{'loss': 0.896, 'grad_norm': 1.4834133386611938, 'learning_rate': 2.1031390134529146e-05, 'epoch': 1.07}


                                                   
 36%|███▌      | 1300/3645 [10:03<16:55,  2.31it/s]

{'eval_loss': 0.6090807318687439, 'eval_runtime': 1.5812, 'eval_samples_per_second': 189.725, 'eval_steps_per_second': 24.032, 'epoch': 1.07}


 38%|███▊      | 1400/3645 [10:44<17:16,  2.17it/s]

{'loss': 0.8413, 'grad_norm': 1.2618285417556763, 'learning_rate': 2.013452914798206e-05, 'epoch': 1.15}


                                                   
 38%|███▊      | 1400/3645 [10:46<17:16,  2.17it/s]

{'eval_loss': 0.6078020334243774, 'eval_runtime': 1.6519, 'eval_samples_per_second': 181.613, 'eval_steps_per_second': 23.004, 'epoch': 1.15}


 41%|████      | 1500/3645 [11:29<14:28,  2.47it/s]

{'loss': 0.8222, 'grad_norm': 1.6590497493743896, 'learning_rate': 1.9237668161434976e-05, 'epoch': 1.23}


                                                   
 41%|████      | 1500/3645 [11:31<14:28,  2.47it/s]

{'eval_loss': 0.6034665107727051, 'eval_runtime': 1.4979, 'eval_samples_per_second': 200.274, 'eval_steps_per_second': 25.368, 'epoch': 1.23}


 44%|████▍     | 1600/3645 [12:14<14:10,  2.40it/s]

{'loss': 0.7975, 'grad_norm': 1.4050182104110718, 'learning_rate': 1.834080717488789e-05, 'epoch': 1.32}


                                                   
 44%|████▍     | 1600/3645 [12:16<14:10,  2.40it/s]

{'eval_loss': 0.6031021475791931, 'eval_runtime': 1.5101, 'eval_samples_per_second': 198.667, 'eval_steps_per_second': 25.164, 'epoch': 1.32}


 47%|████▋     | 1700/3645 [12:58<12:57,  2.50it/s]

{'loss': 0.8395, 'grad_norm': 1.9426754713058472, 'learning_rate': 1.7443946188340808e-05, 'epoch': 1.4}


                                                   
 47%|████▋     | 1700/3645 [12:59<12:57,  2.50it/s]

{'eval_loss': 0.5985274910926819, 'eval_runtime': 1.4682, 'eval_samples_per_second': 204.326, 'eval_steps_per_second': 25.881, 'epoch': 1.4}


 49%|████▉     | 1800/3645 [13:39<11:57,  2.57it/s]

{'loss': 0.8118, 'grad_norm': 1.3921880722045898, 'learning_rate': 1.6547085201793723e-05, 'epoch': 1.48}


                                                   
 49%|████▉     | 1800/3645 [13:41<11:57,  2.57it/s]

{'eval_loss': 0.59911048412323, 'eval_runtime': 1.4613, 'eval_samples_per_second': 205.296, 'eval_steps_per_second': 26.004, 'epoch': 1.48}


 52%|█████▏    | 1900/3645 [14:22<11:40,  2.49it/s]

{'loss': 0.8572, 'grad_norm': 2.0433287620544434, 'learning_rate': 1.5650224215246637e-05, 'epoch': 1.56}


                                                   
 52%|█████▏    | 1900/3645 [14:23<11:40,  2.49it/s]

{'eval_loss': 0.5952101349830627, 'eval_runtime': 1.4727, 'eval_samples_per_second': 203.703, 'eval_steps_per_second': 25.802, 'epoch': 1.56}


 55%|█████▍    | 2000/3645 [15:04<11:25,  2.40it/s]

{'loss': 0.8194, 'grad_norm': 1.5059988498687744, 'learning_rate': 1.4753363228699552e-05, 'epoch': 1.65}


                                                   
 55%|█████▍    | 2000/3645 [15:05<11:25,  2.40it/s]

{'eval_loss': 0.592438817024231, 'eval_runtime': 1.5575, 'eval_samples_per_second': 192.621, 'eval_steps_per_second': 24.399, 'epoch': 1.65}


 58%|█████▊    | 2100/3645 [15:48<10:44,  2.40it/s]

{'loss': 0.8288, 'grad_norm': 1.969853162765503, 'learning_rate': 1.3856502242152466e-05, 'epoch': 1.73}


                                                   
 58%|█████▊    | 2100/3645 [15:50<10:44,  2.40it/s]

{'eval_loss': 0.5914881229400635, 'eval_runtime': 1.673, 'eval_samples_per_second': 179.316, 'eval_steps_per_second': 22.713, 'epoch': 1.73}


 60%|██████    | 2200/3645 [16:36<09:54,  2.43it/s]

{'loss': 0.8036, 'grad_norm': 1.667931318283081, 'learning_rate': 1.2959641255605381e-05, 'epoch': 1.81}


                                                   
 60%|██████    | 2200/3645 [16:38<09:54,  2.43it/s]

{'eval_loss': 0.5909345746040344, 'eval_runtime': 1.5457, 'eval_samples_per_second': 194.09, 'eval_steps_per_second': 24.585, 'epoch': 1.81}


 63%|██████▎   | 2300/3645 [17:20<08:58,  2.50it/s]

{'loss': 0.8317, 'grad_norm': 1.520100712776184, 'learning_rate': 1.2062780269058296e-05, 'epoch': 1.89}


                                                   
 63%|██████▎   | 2300/3645 [17:21<08:58,  2.50it/s]

{'eval_loss': 0.5897814035415649, 'eval_runtime': 1.5829, 'eval_samples_per_second': 189.522, 'eval_steps_per_second': 24.006, 'epoch': 1.89}


 66%|██████▌   | 2400/3645 [18:07<08:54,  2.33it/s]

{'loss': 0.8079, 'grad_norm': 1.8058706521987915, 'learning_rate': 1.1165919282511212e-05, 'epoch': 1.97}


                                                   
 66%|██████▌   | 2400/3645 [18:08<08:54,  2.33it/s]

{'eval_loss': 0.5880868434906006, 'eval_runtime': 1.5297, 'eval_samples_per_second': 196.119, 'eval_steps_per_second': 24.842, 'epoch': 1.97}


 69%|██████▊   | 2500/3645 [18:50<07:30,  2.54it/s]

{'loss': 0.7949, 'grad_norm': 1.9725841283798218, 'learning_rate': 1.0269058295964126e-05, 'epoch': 2.06}


                                                   
 69%|██████▊   | 2500/3645 [18:51<07:30,  2.54it/s]

{'eval_loss': 0.5880577564239502, 'eval_runtime': 1.4898, 'eval_samples_per_second': 201.364, 'eval_steps_per_second': 25.506, 'epoch': 2.06}


 71%|███████▏  | 2600/3645 [19:32<06:53,  2.53it/s]

{'loss': 0.7655, 'grad_norm': 1.5877556800842285, 'learning_rate': 9.372197309417041e-06, 'epoch': 2.14}


                                                   
 71%|███████▏  | 2600/3645 [19:34<06:53,  2.53it/s]

{'eval_loss': 0.5877042412757874, 'eval_runtime': 1.4946, 'eval_samples_per_second': 200.716, 'eval_steps_per_second': 25.424, 'epoch': 2.14}


 74%|███████▍  | 2700/3645 [20:15<06:16,  2.51it/s]

{'loss': 0.7832, 'grad_norm': 1.4154342412948608, 'learning_rate': 8.475336322869956e-06, 'epoch': 2.22}


                                                   
 74%|███████▍  | 2700/3645 [20:16<06:16,  2.51it/s]

{'eval_loss': 0.5879653096199036, 'eval_runtime': 1.4759, 'eval_samples_per_second': 203.27, 'eval_steps_per_second': 25.747, 'epoch': 2.22}


 77%|███████▋  | 2800/3645 [20:59<05:51,  2.40it/s]

{'loss': 0.8238, 'grad_norm': 1.9583219289779663, 'learning_rate': 7.578475336322871e-06, 'epoch': 2.3}


                                                   
 77%|███████▋  | 2800/3645 [21:00<05:51,  2.40it/s]

{'eval_loss': 0.5846158862113953, 'eval_runtime': 1.5404, 'eval_samples_per_second': 194.749, 'eval_steps_per_second': 24.668, 'epoch': 2.3}


 80%|███████▉  | 2900/3645 [21:41<05:03,  2.46it/s]

{'loss': 0.8005, 'grad_norm': 1.3142147064208984, 'learning_rate': 6.681614349775786e-06, 'epoch': 2.39}


                                                   
 80%|███████▉  | 2900/3645 [21:43<05:03,  2.46it/s]

{'eval_loss': 0.5828062295913696, 'eval_runtime': 1.5362, 'eval_samples_per_second': 195.285, 'eval_steps_per_second': 24.736, 'epoch': 2.39}


 82%|████████▏ | 3000/3645 [22:24<04:24,  2.44it/s]

{'loss': 0.7918, 'grad_norm': 1.5047531127929688, 'learning_rate': 5.7847533632287e-06, 'epoch': 2.47}


                                                   
 82%|████████▏ | 3000/3645 [22:25<04:24,  2.44it/s]

{'eval_loss': 0.5838493704795837, 'eval_runtime': 1.543, 'eval_samples_per_second': 194.426, 'eval_steps_per_second': 24.627, 'epoch': 2.47}


 85%|████████▌ | 3100/3645 [23:06<03:39,  2.48it/s]

{'loss': 0.7866, 'grad_norm': 1.759975552558899, 'learning_rate': 4.887892376681614e-06, 'epoch': 2.55}


                                                   
 85%|████████▌ | 3100/3645 [23:07<03:39,  2.48it/s]

{'eval_loss': 0.5853977799415588, 'eval_runtime': 1.4473, 'eval_samples_per_second': 207.279, 'eval_steps_per_second': 26.255, 'epoch': 2.55}


 88%|████████▊ | 3200/3645 [23:48<02:51,  2.60it/s]

{'loss': 0.8237, 'grad_norm': 1.5023605823516846, 'learning_rate': 3.991031390134529e-06, 'epoch': 2.63}


                                                   
 88%|████████▊ | 3200/3645 [23:50<02:51,  2.60it/s]

{'eval_loss': 0.5841083526611328, 'eval_runtime': 1.4519, 'eval_samples_per_second': 206.619, 'eval_steps_per_second': 26.172, 'epoch': 2.63}


 91%|█████████ | 3300/3645 [24:32<02:24,  2.39it/s]

{'loss': 0.7917, 'grad_norm': 1.3554308414459229, 'learning_rate': 3.0941704035874443e-06, 'epoch': 2.72}


                                                   
 91%|█████████ | 3300/3645 [24:33<02:24,  2.39it/s]

{'eval_loss': 0.5838132500648499, 'eval_runtime': 1.4636, 'eval_samples_per_second': 204.976, 'eval_steps_per_second': 25.964, 'epoch': 2.72}


 93%|█████████▎| 3400/3645 [25:14<01:38,  2.48it/s]

{'loss': 0.778, 'grad_norm': 1.810310959815979, 'learning_rate': 2.1973094170403584e-06, 'epoch': 2.8}


                                                   
 93%|█████████▎| 3400/3645 [25:15<01:38,  2.48it/s]

{'eval_loss': 0.5835533142089844, 'eval_runtime': 1.4732, 'eval_samples_per_second': 203.638, 'eval_steps_per_second': 25.794, 'epoch': 2.8}


 96%|█████████▌| 3500/3645 [25:56<00:58,  2.50it/s]

{'loss': 0.7737, 'grad_norm': 1.590498447418213, 'learning_rate': 1.3004484304932737e-06, 'epoch': 2.88}


                                                   
 96%|█████████▌| 3500/3645 [25:57<00:58,  2.50it/s]

{'eval_loss': 0.5833763480186462, 'eval_runtime': 1.4743, 'eval_samples_per_second': 203.491, 'eval_steps_per_second': 25.776, 'epoch': 2.88}


 99%|█████████▉| 3600/3645 [26:38<00:17,  2.52it/s]

{'loss': 0.806, 'grad_norm': 1.4529438018798828, 'learning_rate': 4.035874439461884e-07, 'epoch': 2.96}


                                                   
 99%|█████████▉| 3600/3645 [26:40<00:17,  2.52it/s]

{'eval_loss': 0.5829160809516907, 'eval_runtime': 1.5051, 'eval_samples_per_second': 199.323, 'eval_steps_per_second': 25.248, 'epoch': 2.96}


100%|██████████| 3645/3645 [26:59<00:00,  2.25it/s]

{'train_runtime': 1620.0174, 'train_samples_per_second': 36.007, 'train_steps_per_second': 2.25, 'train_loss': 1.0346248799390754, 'epoch': 3.0}





TrainOutput(global_step=3645, training_loss=1.0346248799390754, metrics={'train_runtime': 1620.0174, 'train_samples_per_second': 36.007, 'train_steps_per_second': 2.25, 'total_flos': 399857836228608.0, 'train_loss': 1.0346248799390754, 'epoch': 2.999382843036412})

In [25]:
trainer.evaluate()

100%|██████████| 38/38 [00:01<00:00, 21.86it/s]


{'eval_loss': 0.5828280448913574,
 'eval_runtime': 1.8275,
 'eval_samples_per_second': 164.163,
 'eval_steps_per_second': 20.794,
 'epoch': 2.999382843036412}

In [26]:
model.eval()

T5ForConditionalGeneration(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 8)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Drop

In [31]:
inputs = tokenizer('The internal policy of the fucking Trump is stupid.', return_tensors='pt')
inputs = {k: v.to(device) for k, v in inputs.items()}
for t in model.generate(**inputs, num_return_sequences=10, do_sample=False, num_beams=10):
    print(tokenizer.decode(t, skip_special_tokens=True))

The internal policy of Trump is stupid.
The internal policy of the Trump is stupid.
The internal policy of Trump is stupid
The internal policy of Trump is not good.
The internal policy of the Trump is stupid
The internal policy of Trump is bad.
The internal policy of the Trump is not good.
The internal policy of the Trump is bad.
The internal policy of Trump is not good
The internal policy of the Trump is not good
