In [None]:
! pip install transformers torch sentencepiece accelerate

In [4]:
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
from torch.utils.data import Dataset
import pandas as pd
from sklearn.model_selection import train_test_split

# Load and preprocess the dataset
file_path = 'quantized_coordinates.csv'
coordinates_df = pd.read_csv(file_path)
coordinates_df['sequence'] = coordinates_df['y_quant'].astype(str) + ' ' + coordinates_df['x_quant'].astype(str)

# Prepare input-output pairs for training
def prepare_data(df, input_len=3):
    input_sequences = []
    output_sequences = []
    for i in range(len(df) - input_len):
        input_seq = ' '.join(df['sequence'].iloc[i:i+input_len])
        output_seq = ' '.join(df['sequence'].iloc[i:i+input_len+1])
        input_sequences.append(input_seq)
        output_sequences.append(output_seq)
    return input_sequences, output_sequences

input_seqs, output_seqs = prepare_data(coordinates_df)

# Split the data into train and validation sets (80% train, 20% validation)
train_inputs, val_inputs, train_outputs, val_outputs = train_test_split(
    input_seqs, output_seqs, test_size=0.2, random_state=42
)

# Custom Dataset Class
class CoordinateDataset(Dataset):
    def __init__(self, inputs, outputs, tokenizer, max_len, device):
        self.inputs = inputs
        self.outputs = outputs
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.device = device

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        input_seq = self.inputs[idx]
        output_seq = self.outputs[idx]
        
        inputs = self.tokenizer(input_seq, max_length=self.max_len, padding='max_length', truncation=True, return_tensors="pt")
        outputs = self.tokenizer(output_seq, max_length=self.max_len, padding='max_length', truncation=True, return_tensors="pt")
        
        return {
            'input_ids': inputs.input_ids.flatten().to(self.device),  # Move tensors to MPS
            'attention_mask': inputs.attention_mask.flatten().to(self.device),  # Move tensors to MPS
            'labels': outputs.input_ids.flatten().to(self.device)  # Move tensors to MPS
        }

# Initialize model and tokenizer
model_name = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# Check if MPS is available and move model to MPS
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
model = model.to(device)

# Dataset parameters
MAX_LEN = 50
BATCH_SIZE = 8

# Create train and eval datasets
train_dataset = CoordinateDataset(train_inputs, train_outputs, tokenizer, max_len=MAX_LEN, device=device)
eval_dataset = CoordinateDataset(val_inputs, val_outputs, tokenizer, max_len=MAX_LEN, device=device)

# Training arguments
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=10,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy="steps",  # Evaluate every X steps
    eval_steps=50,  # Evaluation every 50 steps
    save_steps=500,  # Save checkpoint every 500 steps
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)

# Train the model
trainer.train()

# Inference: Generate a sequence given the first few points
def generate_sequence(model, tokenizer, input_sequence, max_length=50):
    inputs = tokenizer(input_sequence, return_tensors="pt").to(device)
    output = model.generate(inputs.input_ids, max_length=max_length)
    return tokenizer.decode(output[0], skip_special_tokens=True)




  1%|          | 10/1830 [41:45<126:41:01, 250.58s/it]
  1%|          | 10/1470 [00:03<08:00,  3.04it/s]

{'loss': 9.3745, 'grad_norm': 168.86508178710938, 'learning_rate': 4.965986394557823e-05, 'epoch': 0.07}


  1%|▏         | 20/1470 [00:07<08:56,  2.71it/s]

{'loss': 4.251, 'grad_norm': 41.80757522583008, 'learning_rate': 4.931972789115647e-05, 'epoch': 0.14}


  2%|▏         | 30/1470 [00:11<09:00,  2.66it/s]

{'loss': 1.7881, 'grad_norm': 5.389647960662842, 'learning_rate': 4.89795918367347e-05, 'epoch': 0.2}


  3%|▎         | 40/1470 [00:14<08:19,  2.86it/s]

{'loss': 1.2663, 'grad_norm': 4.654216766357422, 'learning_rate': 4.8639455782312926e-05, 'epoch': 0.27}


  3%|▎         | 50/1470 [00:18<08:10,  2.89it/s]

{'loss': 1.0263, 'grad_norm': 3.3365211486816406, 'learning_rate': 4.8299319727891155e-05, 'epoch': 0.34}



  3%|▎         | 50/1470 [00:21<08:10,  2.89it/s]

{'eval_loss': 0.5145986676216125, 'eval_runtime': 2.9669, 'eval_samples_per_second': 98.756, 'eval_steps_per_second': 12.471, 'epoch': 0.34}


  4%|▍         | 60/1470 [00:25<08:36,  2.73it/s]

{'loss': 0.8014, 'grad_norm': 2.545766592025757, 'learning_rate': 4.795918367346939e-05, 'epoch': 0.41}


  5%|▍         | 70/1470 [00:29<11:04,  2.11it/s]

{'loss': 0.6587, 'grad_norm': 10.885637283325195, 'learning_rate': 4.761904761904762e-05, 'epoch': 0.48}


  5%|▌         | 80/1470 [00:33<08:29,  2.73it/s]

{'loss': 0.4999, 'grad_norm': 1.4502846002578735, 'learning_rate': 4.7278911564625856e-05, 'epoch': 0.54}


  6%|▌         | 90/1470 [00:38<11:42,  1.96it/s]

{'loss': 0.4794, 'grad_norm': 2.1317436695098877, 'learning_rate': 4.6938775510204086e-05, 'epoch': 0.61}


  7%|▋         | 100/1470 [00:42<08:35,  2.66it/s]

{'loss': 0.44, 'grad_norm': 2.0872983932495117, 'learning_rate': 4.6598639455782315e-05, 'epoch': 0.68}



  7%|▋         | 100/1470 [00:45<08:35,  2.66it/s]

{'eval_loss': 0.2637560963630676, 'eval_runtime': 2.7907, 'eval_samples_per_second': 104.99, 'eval_steps_per_second': 13.258, 'epoch': 0.68}


  7%|▋         | 110/1470 [00:49<08:37,  2.63it/s]

{'loss': 0.3719, 'grad_norm': 1.7253849506378174, 'learning_rate': 4.625850340136055e-05, 'epoch': 0.75}


  8%|▊         | 120/1470 [00:53<07:55,  2.84it/s]

{'loss': 0.3549, 'grad_norm': 3.203784942626953, 'learning_rate': 4.591836734693878e-05, 'epoch': 0.82}


  9%|▉         | 130/1470 [00:57<08:50,  2.53it/s]

{'loss': 0.3148, 'grad_norm': 2.4245100021362305, 'learning_rate': 4.557823129251701e-05, 'epoch': 0.88}


 10%|▉         | 140/1470 [01:00<07:37,  2.91it/s]

{'loss': 0.2879, 'grad_norm': 1.9083384275436401, 'learning_rate': 4.523809523809524e-05, 'epoch': 0.95}


 10%|█         | 150/1470 [01:04<07:40,  2.86it/s]

{'loss': 0.306, 'grad_norm': 1.9761618375778198, 'learning_rate': 4.4897959183673474e-05, 'epoch': 1.02}



 10%|█         | 150/1470 [01:07<07:40,  2.86it/s]

{'eval_loss': 0.2166891098022461, 'eval_runtime': 2.9859, 'eval_samples_per_second': 98.129, 'eval_steps_per_second': 12.392, 'epoch': 1.02}


 11%|█         | 160/1470 [01:10<08:24,  2.60it/s]

{'loss': 0.2809, 'grad_norm': 1.0483455657958984, 'learning_rate': 4.4557823129251704e-05, 'epoch': 1.09}


 12%|█▏        | 170/1470 [01:14<07:57,  2.72it/s]

{'loss': 0.2643, 'grad_norm': 1.1539223194122314, 'learning_rate': 4.421768707482993e-05, 'epoch': 1.16}


 12%|█▏        | 180/1470 [01:18<08:12,  2.62it/s]

{'loss': 0.2653, 'grad_norm': 0.6814648509025574, 'learning_rate': 4.387755102040816e-05, 'epoch': 1.22}


 13%|█▎        | 190/1470 [01:22<08:18,  2.57it/s]

{'loss': 0.2627, 'grad_norm': 1.087445616722107, 'learning_rate': 4.35374149659864e-05, 'epoch': 1.29}


 14%|█▎        | 200/1470 [01:27<08:37,  2.45it/s]

{'loss': 0.2477, 'grad_norm': 2.0018310546875, 'learning_rate': 4.319727891156463e-05, 'epoch': 1.36}



 14%|█▎        | 200/1470 [01:30<08:37,  2.45it/s]

{'eval_loss': 0.20628081262111664, 'eval_runtime': 3.1964, 'eval_samples_per_second': 91.664, 'eval_steps_per_second': 11.575, 'epoch': 1.36}


 14%|█▍        | 210/1470 [01:35<10:16,  2.05it/s]

{'loss': 0.2504, 'grad_norm': 0.8364963531494141, 'learning_rate': 4.2857142857142856e-05, 'epoch': 1.43}


 15%|█▍        | 220/1470 [01:40<09:48,  2.12it/s]

{'loss': 0.2509, 'grad_norm': 1.0136009454727173, 'learning_rate': 4.2517006802721085e-05, 'epoch': 1.5}


 16%|█▌        | 230/1470 [01:45<08:59,  2.30it/s]

{'loss': 0.2636, 'grad_norm': 0.8147462010383606, 'learning_rate': 4.217687074829932e-05, 'epoch': 1.56}


 16%|█▋        | 240/1470 [01:50<08:39,  2.37it/s]

{'loss': 0.2371, 'grad_norm': 2.2191452980041504, 'learning_rate': 4.183673469387756e-05, 'epoch': 1.63}


 17%|█▋        | 250/1470 [01:54<07:29,  2.71it/s]

{'loss': 0.2456, 'grad_norm': 2.1526572704315186, 'learning_rate': 4.149659863945579e-05, 'epoch': 1.7}



 17%|█▋        | 250/1470 [01:57<07:29,  2.71it/s]

{'eval_loss': 0.198953777551651, 'eval_runtime': 3.3424, 'eval_samples_per_second': 87.663, 'eval_steps_per_second': 11.07, 'epoch': 1.7}


 18%|█▊        | 260/1470 [02:01<08:38,  2.33it/s]

{'loss': 0.2281, 'grad_norm': 2.44036865234375, 'learning_rate': 4.1156462585034016e-05, 'epoch': 1.77}


 18%|█▊        | 270/1470 [02:05<07:18,  2.74it/s]

{'loss': 0.2344, 'grad_norm': 0.7456076741218567, 'learning_rate': 4.0816326530612245e-05, 'epoch': 1.84}


 19%|█▉        | 280/1470 [02:10<08:28,  2.34it/s]

{'loss': 0.2301, 'grad_norm': 0.917375922203064, 'learning_rate': 4.047619047619048e-05, 'epoch': 1.9}


 20%|█▉        | 290/1470 [02:14<08:34,  2.29it/s]

{'loss': 0.2286, 'grad_norm': 1.4033018350601196, 'learning_rate': 4.013605442176871e-05, 'epoch': 1.97}


 20%|██        | 300/1470 [02:18<08:12,  2.38it/s]

{'loss': 0.2251, 'grad_norm': 0.7006782293319702, 'learning_rate': 3.979591836734694e-05, 'epoch': 2.04}



 20%|██        | 300/1470 [02:22<08:12,  2.38it/s]

{'eval_loss': 0.1935480386018753, 'eval_runtime': 3.4847, 'eval_samples_per_second': 84.083, 'eval_steps_per_second': 10.618, 'epoch': 2.04}


 21%|██        | 310/1470 [02:26<07:58,  2.43it/s]

{'loss': 0.2156, 'grad_norm': 0.7530012130737305, 'learning_rate': 3.945578231292517e-05, 'epoch': 2.11}


 22%|██▏       | 320/1470 [02:31<07:54,  2.43it/s]

{'loss': 0.207, 'grad_norm': 0.515558123588562, 'learning_rate': 3.9115646258503405e-05, 'epoch': 2.18}


 22%|██▏       | 330/1470 [02:36<12:22,  1.54it/s]

{'loss': 0.2296, 'grad_norm': 2.2431561946868896, 'learning_rate': 3.8775510204081634e-05, 'epoch': 2.24}


 23%|██▎       | 340/1470 [02:42<08:36,  2.19it/s]

{'loss': 0.2172, 'grad_norm': 0.5580596923828125, 'learning_rate': 3.843537414965986e-05, 'epoch': 2.31}


 24%|██▍       | 350/1470 [02:47<10:07,  1.84it/s]

{'loss': 0.2368, 'grad_norm': 0.693956196308136, 'learning_rate': 3.809523809523809e-05, 'epoch': 2.38}



 24%|██▍       | 350/1470 [02:52<10:07,  1.84it/s]

{'eval_loss': 0.18862278759479523, 'eval_runtime': 4.1145, 'eval_samples_per_second': 71.212, 'eval_steps_per_second': 8.993, 'epoch': 2.38}


 24%|██▍       | 360/1470 [02:57<11:22,  1.63it/s]

{'loss': 0.2305, 'grad_norm': 0.5215088129043579, 'learning_rate': 3.775510204081633e-05, 'epoch': 2.45}


 25%|██▌       | 370/1470 [03:02<08:38,  2.12it/s]

{'loss': 0.2244, 'grad_norm': 0.4871145784854889, 'learning_rate': 3.7414965986394564e-05, 'epoch': 2.52}


 26%|██▌       | 380/1470 [03:07<10:25,  1.74it/s]

{'loss': 0.2125, 'grad_norm': 0.7567358613014221, 'learning_rate': 3.707482993197279e-05, 'epoch': 2.59}


 27%|██▋       | 390/1470 [03:12<07:57,  2.26it/s]

{'loss': 0.2021, 'grad_norm': 0.8796675801277161, 'learning_rate': 3.673469387755102e-05, 'epoch': 2.65}


 27%|██▋       | 400/1470 [03:16<07:40,  2.32it/s]

{'loss': 0.2127, 'grad_norm': 0.6543129682540894, 'learning_rate': 3.639455782312925e-05, 'epoch': 2.72}



 27%|██▋       | 400/1470 [03:20<07:40,  2.32it/s]

{'eval_loss': 0.18782266974449158, 'eval_runtime': 3.68, 'eval_samples_per_second': 79.619, 'eval_steps_per_second': 10.054, 'epoch': 2.72}


 28%|██▊       | 410/1470 [03:25<09:04,  1.94it/s]

{'loss': 0.2232, 'grad_norm': 0.5065346360206604, 'learning_rate': 3.605442176870749e-05, 'epoch': 2.79}


 29%|██▊       | 420/1470 [03:29<07:25,  2.36it/s]

{'loss': 0.2163, 'grad_norm': 0.6678345203399658, 'learning_rate': 3.571428571428572e-05, 'epoch': 2.86}


 29%|██▉       | 430/1470 [03:34<07:03,  2.46it/s]

{'loss': 0.209, 'grad_norm': 1.570968747138977, 'learning_rate': 3.5374149659863946e-05, 'epoch': 2.93}


 30%|██▉       | 440/1470 [03:40<10:58,  1.57it/s]

{'loss': 0.2002, 'grad_norm': 0.5817743539810181, 'learning_rate': 3.5034013605442175e-05, 'epoch': 2.99}


 31%|███       | 450/1470 [03:45<08:32,  1.99it/s]

{'loss': 0.2083, 'grad_norm': 0.46079134941101074, 'learning_rate': 3.469387755102041e-05, 'epoch': 3.06}



 31%|███       | 450/1470 [03:49<08:32,  1.99it/s]

{'eval_loss': 0.18311123549938202, 'eval_runtime': 4.096, 'eval_samples_per_second': 71.534, 'eval_steps_per_second': 9.033, 'epoch': 3.06}


 31%|███▏      | 460/1470 [03:54<07:36,  2.21it/s]

{'loss': 0.1952, 'grad_norm': 1.515023112297058, 'learning_rate': 3.435374149659864e-05, 'epoch': 3.13}


 32%|███▏      | 470/1470 [03:59<07:19,  2.28it/s]

{'loss': 0.2031, 'grad_norm': 0.7397546768188477, 'learning_rate': 3.401360544217687e-05, 'epoch': 3.2}


 33%|███▎      | 480/1470 [04:04<08:43,  1.89it/s]

{'loss': 0.2171, 'grad_norm': 0.5250166654586792, 'learning_rate': 3.36734693877551e-05, 'epoch': 3.27}


 33%|███▎      | 490/1470 [04:10<08:04,  2.02it/s]

{'loss': 0.1971, 'grad_norm': 0.515044093132019, 'learning_rate': 3.3333333333333335e-05, 'epoch': 3.33}


 34%|███▍      | 500/1470 [04:14<06:27,  2.50it/s]

{'loss': 0.208, 'grad_norm': 0.8916515111923218, 'learning_rate': 3.2993197278911564e-05, 'epoch': 3.4}



 34%|███▍      | 500/1470 [04:18<06:27,  2.50it/s]

{'eval_loss': 0.1817120462656021, 'eval_runtime': 3.8684, 'eval_samples_per_second': 75.741, 'eval_steps_per_second': 9.565, 'epoch': 3.4}


 35%|███▍      | 510/1470 [04:26<07:31,  2.13it/s]

{'loss': 0.2127, 'grad_norm': 0.46287021040916443, 'learning_rate': 3.265306122448979e-05, 'epoch': 3.47}


 35%|███▌      | 520/1470 [04:31<06:39,  2.38it/s]

{'loss': 0.1924, 'grad_norm': 0.7354789972305298, 'learning_rate': 3.231292517006803e-05, 'epoch': 3.54}


 36%|███▌      | 530/1470 [04:36<07:53,  1.99it/s]

{'loss': 0.2121, 'grad_norm': 0.7746792435646057, 'learning_rate': 3.1972789115646265e-05, 'epoch': 3.61}


 37%|███▋      | 540/1470 [04:41<06:29,  2.39it/s]

{'loss': 0.1986, 'grad_norm': 1.4049304723739624, 'learning_rate': 3.1632653061224494e-05, 'epoch': 3.67}


 37%|███▋      | 550/1470 [04:46<08:17,  1.85it/s]

{'loss': 0.2036, 'grad_norm': 0.7001219987869263, 'learning_rate': 3.1292517006802724e-05, 'epoch': 3.74}



 37%|███▋      | 550/1470 [04:51<08:17,  1.85it/s]

{'eval_loss': 0.1792188286781311, 'eval_runtime': 4.7746, 'eval_samples_per_second': 61.366, 'eval_steps_per_second': 7.749, 'epoch': 3.74}


 38%|███▊      | 560/1470 [04:56<07:24,  2.05it/s]

{'loss': 0.2055, 'grad_norm': 0.5151751041412354, 'learning_rate': 3.095238095238095e-05, 'epoch': 3.81}


 39%|███▉      | 570/1470 [05:02<08:31,  1.76it/s]

{'loss': 0.1997, 'grad_norm': 0.40025171637535095, 'learning_rate': 3.061224489795919e-05, 'epoch': 3.88}


 39%|███▉      | 580/1470 [05:08<07:54,  1.88it/s]

{'loss': 0.2047, 'grad_norm': 0.5332114696502686, 'learning_rate': 3.0272108843537418e-05, 'epoch': 3.95}


 40%|████      | 590/1470 [05:12<06:53,  2.13it/s]

{'loss': 0.1913, 'grad_norm': 0.9375244975090027, 'learning_rate': 2.9931972789115647e-05, 'epoch': 4.01}


 41%|████      | 600/1470 [05:21<10:42,  1.35it/s]

{'loss': 0.185, 'grad_norm': 0.4009455442428589, 'learning_rate': 2.959183673469388e-05, 'epoch': 4.08}



 41%|████      | 600/1470 [05:27<10:42,  1.35it/s]

{'eval_loss': 0.1792399138212204, 'eval_runtime': 6.0068, 'eval_samples_per_second': 48.778, 'eval_steps_per_second': 6.16, 'epoch': 4.08}


 41%|████▏     | 610/1470 [05:35<11:02,  1.30it/s]

{'loss': 0.1994, 'grad_norm': 0.5179342031478882, 'learning_rate': 2.925170068027211e-05, 'epoch': 4.15}


 42%|████▏     | 620/1470 [05:44<10:51,  1.31it/s]

{'loss': 0.2102, 'grad_norm': 0.4285053014755249, 'learning_rate': 2.891156462585034e-05, 'epoch': 4.22}


 43%|████▎     | 630/1470 [05:51<11:10,  1.25it/s]

{'loss': 0.1875, 'grad_norm': 0.6340763568878174, 'learning_rate': 2.857142857142857e-05, 'epoch': 4.29}


 44%|████▎     | 640/1470 [05:59<08:48,  1.57it/s]

{'loss': 0.198, 'grad_norm': 2.476027011871338, 'learning_rate': 2.8231292517006803e-05, 'epoch': 4.35}


 44%|████▍     | 650/1470 [06:05<07:18,  1.87it/s]

{'loss': 0.205, 'grad_norm': 0.5060441493988037, 'learning_rate': 2.7891156462585033e-05, 'epoch': 4.42}



 44%|████▍     | 650/1470 [06:11<07:18,  1.87it/s]

{'eval_loss': 0.17851261794567108, 'eval_runtime': 6.0014, 'eval_samples_per_second': 48.822, 'eval_steps_per_second': 6.165, 'epoch': 4.42}


 45%|████▍     | 660/1470 [06:21<15:25,  1.14s/it]

{'loss': 0.1974, 'grad_norm': 0.5780653357505798, 'learning_rate': 2.7551020408163265e-05, 'epoch': 4.49}


 46%|████▌     | 670/1470 [06:29<09:05,  1.47it/s]

{'loss': 0.1949, 'grad_norm': 0.5407294631004333, 'learning_rate': 2.72108843537415e-05, 'epoch': 4.56}


 46%|████▋     | 680/1470 [06:37<11:22,  1.16it/s]

{'loss': 0.1853, 'grad_norm': 0.46708691120147705, 'learning_rate': 2.687074829931973e-05, 'epoch': 4.63}


 47%|████▋     | 690/1470 [06:46<08:29,  1.53it/s]

{'loss': 0.1969, 'grad_norm': 0.5628882646560669, 'learning_rate': 2.6530612244897963e-05, 'epoch': 4.69}


 48%|████▊     | 700/1470 [06:54<09:55,  1.29it/s]

{'loss': 0.1948, 'grad_norm': 0.3492477536201477, 'learning_rate': 2.6190476190476192e-05, 'epoch': 4.76}



 48%|████▊     | 700/1470 [06:58<09:55,  1.29it/s]

{'eval_loss': 0.17731736600399017, 'eval_runtime': 4.2456, 'eval_samples_per_second': 69.012, 'eval_steps_per_second': 8.715, 'epoch': 4.76}


 48%|████▊     | 710/1470 [07:05<07:23,  1.72it/s]

{'loss': 0.2032, 'grad_norm': 1.5003972053527832, 'learning_rate': 2.5850340136054425e-05, 'epoch': 4.83}


 49%|████▉     | 720/1470 [07:10<05:57,  2.10it/s]

{'loss': 0.206, 'grad_norm': 0.5774499177932739, 'learning_rate': 2.5510204081632654e-05, 'epoch': 4.9}


 50%|████▉     | 730/1470 [07:16<05:03,  2.44it/s]

{'loss': 0.202, 'grad_norm': 0.4383357763290405, 'learning_rate': 2.5170068027210887e-05, 'epoch': 4.97}


 50%|█████     | 740/1470 [07:20<05:26,  2.23it/s]

{'loss': 0.1846, 'grad_norm': 0.4428597092628479, 'learning_rate': 2.4829931972789116e-05, 'epoch': 5.03}


 51%|█████     | 750/1470 [07:25<05:16,  2.28it/s]

{'loss': 0.1965, 'grad_norm': 0.4101014733314514, 'learning_rate': 2.448979591836735e-05, 'epoch': 5.1}



 51%|█████     | 750/1470 [07:29<05:16,  2.28it/s]

{'eval_loss': 0.17652666568756104, 'eval_runtime': 3.8363, 'eval_samples_per_second': 76.376, 'eval_steps_per_second': 9.645, 'epoch': 5.1}


 52%|█████▏    | 760/1470 [07:34<05:30,  2.15it/s]

{'loss': 0.1957, 'grad_norm': 0.75242680311203, 'learning_rate': 2.4149659863945578e-05, 'epoch': 5.17}


 52%|█████▏    | 770/1470 [07:38<05:20,  2.18it/s]

{'loss': 0.1769, 'grad_norm': 0.7737071514129639, 'learning_rate': 2.380952380952381e-05, 'epoch': 5.24}


 53%|█████▎    | 780/1470 [07:43<05:12,  2.21it/s]

{'loss': 0.1906, 'grad_norm': 0.49271151423454285, 'learning_rate': 2.3469387755102043e-05, 'epoch': 5.31}


 54%|█████▎    | 790/1470 [07:48<04:51,  2.33it/s]

{'loss': 0.1989, 'grad_norm': 0.4444643557071686, 'learning_rate': 2.3129251700680275e-05, 'epoch': 5.37}


 54%|█████▍    | 800/1470 [07:53<05:09,  2.16it/s]

{'loss': 0.1999, 'grad_norm': 0.5632987022399902, 'learning_rate': 2.2789115646258505e-05, 'epoch': 5.44}



 54%|█████▍    | 800/1470 [07:57<05:09,  2.16it/s]

{'eval_loss': 0.17601221799850464, 'eval_runtime': 3.7025, 'eval_samples_per_second': 79.137, 'eval_steps_per_second': 9.993, 'epoch': 5.44}


 55%|█████▌    | 810/1470 [08:01<05:22,  2.05it/s]

{'loss': 0.203, 'grad_norm': 0.36531081795692444, 'learning_rate': 2.2448979591836737e-05, 'epoch': 5.51}


 56%|█████▌    | 820/1470 [08:08<09:12,  1.18it/s]

{'loss': 0.179, 'grad_norm': 0.5537640452384949, 'learning_rate': 2.2108843537414966e-05, 'epoch': 5.58}


 56%|█████▋    | 830/1470 [08:14<07:19,  1.46it/s]

{'loss': 0.1874, 'grad_norm': 0.8863208293914795, 'learning_rate': 2.17687074829932e-05, 'epoch': 5.65}


 57%|█████▋    | 840/1470 [08:20<07:03,  1.49it/s]

{'loss': 0.1978, 'grad_norm': 0.47732725739479065, 'learning_rate': 2.1428571428571428e-05, 'epoch': 5.71}


 58%|█████▊    | 850/1470 [08:25<06:23,  1.62it/s]

{'loss': 0.1958, 'grad_norm': 0.3308524787425995, 'learning_rate': 2.108843537414966e-05, 'epoch': 5.78}



 58%|█████▊    | 850/1470 [08:29<06:23,  1.62it/s]

{'eval_loss': 0.17480464279651642, 'eval_runtime': 3.748, 'eval_samples_per_second': 78.174, 'eval_steps_per_second': 9.872, 'epoch': 5.78}


 59%|█████▊    | 860/1470 [08:34<05:06,  1.99it/s]

{'loss': 0.192, 'grad_norm': 1.0222359895706177, 'learning_rate': 2.0748299319727893e-05, 'epoch': 5.85}


 59%|█████▉    | 870/1470 [08:39<04:36,  2.17it/s]

{'loss': 0.1912, 'grad_norm': 0.8697932958602905, 'learning_rate': 2.0408163265306123e-05, 'epoch': 5.92}


 60%|█████▉    | 880/1470 [08:44<04:30,  2.18it/s]

{'loss': 0.1959, 'grad_norm': 0.6425577402114868, 'learning_rate': 2.0068027210884355e-05, 'epoch': 5.99}


 61%|██████    | 890/1470 [08:48<03:53,  2.49it/s]

{'loss': 0.1693, 'grad_norm': 0.2585643231868744, 'learning_rate': 1.9727891156462584e-05, 'epoch': 6.05}


 61%|██████    | 900/1470 [08:55<06:23,  1.49it/s]

{'loss': 0.1828, 'grad_norm': 0.4403946101665497, 'learning_rate': 1.9387755102040817e-05, 'epoch': 6.12}



 61%|██████    | 900/1470 [08:58<06:23,  1.49it/s]

{'eval_loss': 0.17481040954589844, 'eval_runtime': 3.8611, 'eval_samples_per_second': 75.884, 'eval_steps_per_second': 9.583, 'epoch': 6.12}


 62%|██████▏   | 910/1470 [09:04<05:10,  1.80it/s]

{'loss': 0.2083, 'grad_norm': 1.0532658100128174, 'learning_rate': 1.9047619047619046e-05, 'epoch': 6.19}


 63%|██████▎   | 920/1470 [09:10<04:55,  1.86it/s]

{'loss': 0.1976, 'grad_norm': 0.5361754298210144, 'learning_rate': 1.8707482993197282e-05, 'epoch': 6.26}


 63%|██████▎   | 930/1470 [09:15<04:08,  2.17it/s]

{'loss': 0.1906, 'grad_norm': 0.40411171317100525, 'learning_rate': 1.836734693877551e-05, 'epoch': 6.33}


 64%|██████▍   | 940/1470 [09:20<03:59,  2.21it/s]

{'loss': 0.1725, 'grad_norm': 0.4704535901546478, 'learning_rate': 1.8027210884353744e-05, 'epoch': 6.39}


 65%|██████▍   | 950/1470 [09:26<04:11,  2.07it/s]

{'loss': 0.1716, 'grad_norm': 0.3814546763896942, 'learning_rate': 1.7687074829931973e-05, 'epoch': 6.46}



 65%|██████▍   | 950/1470 [09:30<04:11,  2.07it/s]

{'eval_loss': 0.1749599575996399, 'eval_runtime': 4.1896, 'eval_samples_per_second': 69.935, 'eval_steps_per_second': 8.831, 'epoch': 6.46}


 65%|██████▌   | 960/1470 [09:35<04:04,  2.09it/s]

{'loss': 0.1813, 'grad_norm': 0.7732194662094116, 'learning_rate': 1.7346938775510206e-05, 'epoch': 6.53}


 66%|██████▌   | 970/1470 [09:41<05:08,  1.62it/s]

{'loss': 0.1966, 'grad_norm': 0.3930586278438568, 'learning_rate': 1.7006802721088435e-05, 'epoch': 6.6}


 67%|██████▋   | 980/1470 [09:46<04:11,  1.95it/s]

{'loss': 0.1891, 'grad_norm': 0.6089556217193604, 'learning_rate': 1.6666666666666667e-05, 'epoch': 6.67}


 67%|██████▋   | 990/1470 [09:51<04:11,  1.91it/s]

{'loss': 0.191, 'grad_norm': 1.453655481338501, 'learning_rate': 1.6326530612244897e-05, 'epoch': 6.73}


 68%|██████▊   | 1000/1470 [09:58<05:01,  1.56it/s]

{'loss': 0.1897, 'grad_norm': 0.6984356045722961, 'learning_rate': 1.5986394557823133e-05, 'epoch': 6.8}



 68%|██████▊   | 1000/1470 [10:02<05:01,  1.56it/s]

{'eval_loss': 0.1744009405374527, 'eval_runtime': 3.6732, 'eval_samples_per_second': 79.766, 'eval_steps_per_second': 10.073, 'epoch': 6.8}


 69%|██████▊   | 1010/1470 [10:13<07:25,  1.03it/s]

{'loss': 0.1886, 'grad_norm': 1.0650113821029663, 'learning_rate': 1.5646258503401362e-05, 'epoch': 6.87}


 69%|██████▉   | 1020/1470 [10:20<04:51,  1.54it/s]

{'loss': 0.1899, 'grad_norm': 0.37277984619140625, 'learning_rate': 1.5306122448979594e-05, 'epoch': 6.94}


 70%|███████   | 1030/1470 [10:27<04:41,  1.56it/s]

{'loss': 0.1892, 'grad_norm': 0.38087841868400574, 'learning_rate': 1.4965986394557824e-05, 'epoch': 7.01}


 71%|███████   | 1040/1470 [10:33<04:11,  1.71it/s]

{'loss': 0.1845, 'grad_norm': 0.45975854992866516, 'learning_rate': 1.4625850340136055e-05, 'epoch': 7.07}


 71%|███████▏  | 1050/1470 [10:38<03:37,  1.93it/s]

{'loss': 0.1888, 'grad_norm': 0.47883450984954834, 'learning_rate': 1.4285714285714285e-05, 'epoch': 7.14}



 71%|███████▏  | 1050/1470 [10:44<03:37,  1.93it/s]

{'eval_loss': 0.17394432425498962, 'eval_runtime': 5.2415, 'eval_samples_per_second': 55.9, 'eval_steps_per_second': 7.059, 'epoch': 7.14}


 72%|███████▏  | 1060/1470 [10:51<05:17,  1.29it/s]

{'loss': 0.1938, 'grad_norm': 0.40599313378334045, 'learning_rate': 1.3945578231292516e-05, 'epoch': 7.21}


 73%|███████▎  | 1070/1470 [10:59<04:12,  1.58it/s]

{'loss': 0.1921, 'grad_norm': 0.46577808260917664, 'learning_rate': 1.360544217687075e-05, 'epoch': 7.28}


 73%|███████▎  | 1080/1470 [11:06<04:10,  1.56it/s]

{'loss': 0.1792, 'grad_norm': 0.45846354961395264, 'learning_rate': 1.3265306122448982e-05, 'epoch': 7.35}


 74%|███████▍  | 1090/1470 [11:13<06:19,  1.00it/s]

{'loss': 0.1972, 'grad_norm': 0.4309147298336029, 'learning_rate': 1.2925170068027212e-05, 'epoch': 7.41}


 75%|███████▍  | 1100/1470 [11:20<04:31,  1.36it/s]

{'loss': 0.1862, 'grad_norm': 0.5886499881744385, 'learning_rate': 1.2585034013605443e-05, 'epoch': 7.48}



 75%|███████▍  | 1100/1470 [11:26<04:31,  1.36it/s]

{'eval_loss': 0.17304204404354095, 'eval_runtime': 5.6372, 'eval_samples_per_second': 51.976, 'eval_steps_per_second': 6.564, 'epoch': 7.48}


 76%|███████▌  | 1110/1470 [11:35<04:57,  1.21it/s]

{'loss': 0.1759, 'grad_norm': 0.42560529708862305, 'learning_rate': 1.2244897959183674e-05, 'epoch': 7.55}


 76%|███████▌  | 1120/1470 [11:43<04:29,  1.30it/s]

{'loss': 0.1774, 'grad_norm': 0.452411949634552, 'learning_rate': 1.1904761904761905e-05, 'epoch': 7.62}


 77%|███████▋  | 1130/1470 [11:51<04:06,  1.38it/s]

{'loss': 0.1961, 'grad_norm': 1.511451244354248, 'learning_rate': 1.1564625850340138e-05, 'epoch': 7.69}


 78%|███████▊  | 1140/1470 [11:59<03:38,  1.51it/s]

{'loss': 0.1697, 'grad_norm': 0.4581722617149353, 'learning_rate': 1.1224489795918369e-05, 'epoch': 7.76}


 78%|███████▊  | 1150/1470 [12:07<04:52,  1.09it/s]

{'loss': 0.193, 'grad_norm': 0.520982563495636, 'learning_rate': 1.08843537414966e-05, 'epoch': 7.82}



 78%|███████▊  | 1150/1470 [12:12<04:52,  1.09it/s]

{'eval_loss': 0.17355620861053467, 'eval_runtime': 4.9992, 'eval_samples_per_second': 58.61, 'eval_steps_per_second': 7.401, 'epoch': 7.82}


 79%|███████▉  | 1160/1470 [12:19<03:12,  1.61it/s]

{'loss': 0.175, 'grad_norm': 1.2562882900238037, 'learning_rate': 1.054421768707483e-05, 'epoch': 7.89}


 80%|███████▉  | 1170/1470 [12:27<03:37,  1.38it/s]

{'loss': 0.1855, 'grad_norm': 0.5053359270095825, 'learning_rate': 1.0204081632653061e-05, 'epoch': 7.96}


 80%|████████  | 1180/1470 [12:34<03:27,  1.40it/s]

{'loss': 0.2024, 'grad_norm': 0.5534549355506897, 'learning_rate': 9.863945578231292e-06, 'epoch': 8.03}


 81%|████████  | 1190/1470 [12:42<03:02,  1.54it/s]

{'loss': 0.1905, 'grad_norm': 0.499462753534317, 'learning_rate': 9.523809523809523e-06, 'epoch': 8.1}


 82%|████████▏ | 1200/1470 [12:50<03:26,  1.31it/s]

{'loss': 0.1738, 'grad_norm': 0.3859996199607849, 'learning_rate': 9.183673469387756e-06, 'epoch': 8.16}



 82%|████████▏ | 1200/1470 [12:56<03:26,  1.31it/s]

{'eval_loss': 0.1735721081495285, 'eval_runtime': 6.0044, 'eval_samples_per_second': 48.797, 'eval_steps_per_second': 6.162, 'epoch': 8.16}


 82%|████████▏ | 1210/1470 [13:04<03:04,  1.41it/s]

{'loss': 0.1763, 'grad_norm': 0.6806049942970276, 'learning_rate': 8.843537414965987e-06, 'epoch': 8.23}


 83%|████████▎ | 1220/1470 [13:12<02:53,  1.44it/s]

{'loss': 0.1925, 'grad_norm': 0.8156037926673889, 'learning_rate': 8.503401360544217e-06, 'epoch': 8.3}


 84%|████████▎ | 1230/1470 [13:20<03:50,  1.04it/s]

{'loss': 0.1868, 'grad_norm': 0.6439953446388245, 'learning_rate': 8.163265306122448e-06, 'epoch': 8.37}


 84%|████████▍ | 1240/1470 [13:28<03:04,  1.25it/s]

{'loss': 0.1816, 'grad_norm': 0.3833441734313965, 'learning_rate': 7.823129251700681e-06, 'epoch': 8.44}


 85%|████████▌ | 1250/1470 [13:39<02:50,  1.29it/s]

{'loss': 0.19, 'grad_norm': 0.9287965297698975, 'learning_rate': 7.482993197278912e-06, 'epoch': 8.5}



 85%|████████▌ | 1250/1470 [13:46<02:50,  1.29it/s]

{'eval_loss': 0.17364878952503204, 'eval_runtime': 6.9029, 'eval_samples_per_second': 42.446, 'eval_steps_per_second': 5.36, 'epoch': 8.5}


 86%|████████▌ | 1260/1470 [13:57<03:40,  1.05s/it]

{'loss': 0.1874, 'grad_norm': 0.5057103037834167, 'learning_rate': 7.142857142857143e-06, 'epoch': 8.57}


 86%|████████▋ | 1270/1470 [14:06<02:59,  1.11it/s]

{'loss': 0.1868, 'grad_norm': 0.4397478699684143, 'learning_rate': 6.802721088435375e-06, 'epoch': 8.64}


 87%|████████▋ | 1280/1470 [14:14<02:41,  1.18it/s]

{'loss': 0.1957, 'grad_norm': 1.0980876684188843, 'learning_rate': 6.462585034013606e-06, 'epoch': 8.71}


 88%|████████▊ | 1290/1470 [14:24<02:40,  1.12it/s]

{'loss': 0.198, 'grad_norm': 0.3966902494430542, 'learning_rate': 6.122448979591837e-06, 'epoch': 8.78}


 88%|████████▊ | 1300/1470 [14:33<02:10,  1.30it/s]

{'loss': 0.19, 'grad_norm': 0.5437226295471191, 'learning_rate': 5.782312925170069e-06, 'epoch': 8.84}



 88%|████████▊ | 1300/1470 [14:38<02:10,  1.30it/s]

{'eval_loss': 0.1731949895620346, 'eval_runtime': 5.6322, 'eval_samples_per_second': 52.022, 'eval_steps_per_second': 6.569, 'epoch': 8.84}


 89%|████████▉ | 1310/1470 [14:47<02:12,  1.21it/s]

{'loss': 0.1954, 'grad_norm': 0.4180458188056946, 'learning_rate': 5.4421768707483e-06, 'epoch': 8.91}


 90%|████████▉ | 1320/1470 [14:55<02:00,  1.25it/s]

{'loss': 0.1779, 'grad_norm': 0.42053288221359253, 'learning_rate': 5.102040816326531e-06, 'epoch': 8.98}


 90%|█████████ | 1330/1470 [15:02<01:39,  1.41it/s]

{'loss': 0.184, 'grad_norm': 0.5702763795852661, 'learning_rate': 4.7619047619047615e-06, 'epoch': 9.05}


 91%|█████████ | 1340/1470 [15:10<01:30,  1.43it/s]

{'loss': 0.1882, 'grad_norm': 0.5219219326972961, 'learning_rate': 4.421768707482993e-06, 'epoch': 9.12}


 92%|█████████▏| 1350/1470 [15:19<01:55,  1.04it/s]

{'loss': 0.1808, 'grad_norm': 0.32969433069229126, 'learning_rate': 4.081632653061224e-06, 'epoch': 9.18}



 92%|█████████▏| 1350/1470 [15:25<01:55,  1.04it/s]

{'eval_loss': 0.17303866147994995, 'eval_runtime': 6.1561, 'eval_samples_per_second': 47.595, 'eval_steps_per_second': 6.01, 'epoch': 9.18}


 93%|█████████▎| 1360/1470 [15:35<02:16,  1.24s/it]

{'loss': 0.1834, 'grad_norm': 0.7293479442596436, 'learning_rate': 3.741496598639456e-06, 'epoch': 9.25}


 93%|█████████▎| 1370/1470 [15:44<01:43,  1.03s/it]

{'loss': 0.1943, 'grad_norm': 0.40835410356521606, 'learning_rate': 3.4013605442176877e-06, 'epoch': 9.32}


 94%|█████████▍| 1380/1470 [15:50<00:52,  1.72it/s]

{'loss': 0.1863, 'grad_norm': 2.2532296180725098, 'learning_rate': 3.0612244897959185e-06, 'epoch': 9.39}


 95%|█████████▍| 1390/1470 [15:59<01:03,  1.26it/s]

{'loss': 0.1855, 'grad_norm': 0.42599207162857056, 'learning_rate': 2.72108843537415e-06, 'epoch': 9.46}


 95%|█████████▌| 1400/1470 [16:06<00:45,  1.55it/s]

{'loss': 0.1878, 'grad_norm': 0.5242931246757507, 'learning_rate': 2.3809523809523808e-06, 'epoch': 9.52}



 95%|█████████▌| 1400/1470 [16:11<00:45,  1.55it/s]

{'eval_loss': 0.1731421798467636, 'eval_runtime': 4.97, 'eval_samples_per_second': 58.953, 'eval_steps_per_second': 7.445, 'epoch': 9.52}


 96%|█████████▌| 1410/1470 [16:18<00:35,  1.69it/s]

{'loss': 0.1975, 'grad_norm': 0.4233034551143646, 'learning_rate': 2.040816326530612e-06, 'epoch': 9.59}


 97%|█████████▋| 1420/1470 [16:24<00:32,  1.55it/s]

{'loss': 0.1818, 'grad_norm': 0.4395368695259094, 'learning_rate': 1.7006802721088438e-06, 'epoch': 9.66}


 97%|█████████▋| 1430/1470 [16:29<00:19,  2.01it/s]

{'loss': 0.1825, 'grad_norm': 0.602171003818512, 'learning_rate': 1.360544217687075e-06, 'epoch': 9.73}


 98%|█████████▊| 1440/1470 [16:34<00:14,  2.09it/s]

{'loss': 0.1919, 'grad_norm': 0.5419273972511292, 'learning_rate': 1.020408163265306e-06, 'epoch': 9.8}


 99%|█████████▊| 1450/1470 [16:40<00:08,  2.27it/s]

{'loss': 0.1803, 'grad_norm': 0.6497805714607239, 'learning_rate': 6.802721088435375e-07, 'epoch': 9.86}



 99%|█████████▊| 1450/1470 [16:45<00:08,  2.27it/s]

{'eval_loss': 0.17309126257896423, 'eval_runtime': 4.5361, 'eval_samples_per_second': 64.592, 'eval_steps_per_second': 8.157, 'epoch': 9.86}


 99%|█████████▉| 1460/1470 [16:51<00:07,  1.38it/s]

{'loss': 0.1896, 'grad_norm': 0.8950478434562683, 'learning_rate': 3.4013605442176873e-07, 'epoch': 9.93}


100%|██████████| 1470/1470 [16:58<00:00,  1.51it/s]

{'loss': 0.1813, 'grad_norm': 0.5268719792366028, 'learning_rate': 0.0, 'epoch': 10.0}


100%|██████████| 1470/1470 [17:03<00:00,  1.44it/s]


{'train_runtime': 1023.446, 'train_samples_per_second': 11.432, 'train_steps_per_second': 1.436, 'train_loss': 0.3312650464018997, 'epoch': 10.0}
Input: 88 121 89 121 90
Predicted Sequence: 88 121 89 121 90 88 121


In [6]:
# Example prediction
input_sequence = '80 118 79 123 88 127'
predicted_sequence = generate_sequence(model, tokenizer, input_sequence)
print(f"Input: {input_sequence}")
print(f"Predicted Sequence: {predicted_sequence}")

Input: 80 118 79 123 88 127
Predicted Sequence: 80 118 79 123 88 127 88 127


In [3]:
import torch
print(torch.backends.mps.is_available())  # Should return True on M1 Mac
print(torch.backends.mps.is_built())      # Should return True if PyTorch was built with MPS support


True
True
