<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/TRANSFORMER_REASONING_MODEL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets -q
!pip install transformers -q
!pip install torch -q

In [None]:
# 1. Imports and Setup (Ensure libraries are installed)
# !pip install datasets -q
# !pip install transformers -q # Still needed for scheduler
# !pip install torch -q

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from transformers import get_linear_schedule_with_warmup # Using scheduler from transformers
from torch.optim import AdamW
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence # Keep pad_sequence for custom batching

print("Libraries imported.")

# 2. Hyperparameters
# Model capacity. These values are small; consider increasing them for reasoning tasks.
D_MODEL = 256  # Embedding / hidden width; must be divisible by NUM_HEADS. SUGGESTION: Increase (e.g., 512, 768)
NUM_LAYERS = 3 # Number of encoder layers and of decoder layers. SUGGESTION: Increase (e.g., 6, 8, 12)
D_FF = 512     # Inner width of the position-wise feed-forward block. SUGGESTION: Increase (e.g., 4 * D_MODEL)

# Optimisation and data hyperparameters
BATCH_SIZE = 32
LEARNING_RATE = 1e-5 # Peak LR reached after warmup. SUGGESTION: Tune together with the scheduler when training from scratch
NUM_EPOCHS = 150    # Upper bound on epochs; early stopping (PATIENCE) may end training sooner
NUM_HEADS = 8
DROPOUT = 0.1
MAX_LEN = 128  # Sequences are truncated/padded to this many tokens. SUGGESTION: Analyze GSM8k lengths, consider increasing (e.g., 256, 512)
WARMUP_STEPS = 1000 # Optimizer steps of linear LR warmup passed to the scheduler
GRADIENT_CLIPPING = 1.0 # Max gradient norm used with clip_grad_norm_
PATIENCE = 10 # Early stopping: epochs without evaluation-loss improvement before stopping

print("Hyperparameters set.")

# 3. Load Dataset
# Load the "main" configuration of GSM8k and keep its 'train' and 'test' splits.
# Each example is later read through its 'question' and 'answer' text fields.
print("Loading GSM8k dataset...")
gsm8k_dataset = load_dataset("gsm8k", "main")
train_dataset = gsm8k_dataset['train']
test_dataset = gsm8k_dataset['test']
print("Dataset loaded.")

# 4. Vocabulary Creation (Same as original - !!! MAJOR AREA FOR IMPROVEMENT !!!)
# SUGGESTION: This custom word-level vocabulary is likely insufficient.
# Strongly consider replacing this with a subword tokenizer (e.g., SentencePiece trained on GSM8k, or from Hugging Face).
print("Building custom vocabulary (NOTE: This is a potential bottleneck)...")
def build_vocabulary(examples):
    """Return the sorted list of distinct lowercase tokens found in `examples`.

    Each example is indexed with the keys 'question' and 'answer'. The two
    texts are joined with a space, lowercased and split on whitespace, so
    punctuation and digits stay attached to their neighbouring characters.
    """
    distinct_words = {
        word
        for record in examples
        for word in (record['question'] + " " + record['answer']).lower().split()
    }
    return sorted(distinct_words)

vocabulary = build_vocabulary(train_dataset)
vocab_size = len(vocabulary)

# Special tokens
PAD_TOKEN = "<pad>"
START_TOKEN = "<start>"
END_TOKEN = "<end>"
UNK_TOKEN = "<unk>"

# Index 0 is reserved exclusively for padding, so real words are numbered
# from 1. If words started at 0, the first vocabulary word would share its
# index with PAD: it would be masked out of attention, skipped by the loss
# (ignore_index=PAD_INDEX) and decoded back as "<pad>".
PAD_INDEX = 0
START_INDEX = vocab_size + 1
END_INDEX = vocab_size + 2
UNK_INDEX = vocab_size + 3

word_to_index = {word: i for i, word in enumerate(vocabulary, start=1)}
word_to_index[PAD_TOKEN] = PAD_INDEX
word_to_index[START_TOKEN] = START_INDEX
word_to_index[END_TOKEN] = END_INDEX
word_to_index[UNK_TOKEN] = UNK_INDEX

# Inverse mapping, built once every index is final.
index_to_word = {i: word for word, i in word_to_index.items()}

# Embedding tables need one row per index, i.e. (largest index + 1) rows.
updated_vocab_size = UNK_INDEX + 1
print(f"Custom vocabulary built. Size: {updated_vocab_size}")


# 5. Data Processing Function (Same as original - tied to custom vocabulary)
# SUGGESTION: Adapt this significantly if using a better tokenizer.
def process_example(example, max_len, word_to_index):
    """Turn one question/answer example into three id tensors padded to `max_len`.

    Args:
        example: mapping with 'question' and 'answer' strings.
        max_len: length every sequence is truncated to and then padded up to.
        word_to_index: token -> id mapping; unknown tokens become UNK_INDEX.

    Returns:
        (src, tgt_input, tgt_output):
        src        = START + question + END, truncated to `max_len`
                     (so END is dropped when the question is too long);
        tgt_input  = START + answer, truncated to `max_len`;
        tgt_output = answer shifted by one, as long as tgt_input and always
                     finishing with END (also after truncation).
        Remaining positions are filled with PAD_INDEX.
    """
    def encode(text):
        # Lowercase + whitespace split, mirroring how the vocabulary is built.
        return [word_to_index.get(piece, UNK_INDEX) for piece in text.lower().split()]

    def pad_to_tensor(ids):
        return torch.tensor(ids + [PAD_INDEX] * (max_len - len(ids)))

    question_ids = encode(example['question'])
    answer_ids = encode(example['answer'])

    encoder_ids = ([START_INDEX] + question_ids + [END_INDEX])[:max_len]
    decoder_in_ids = ([START_INDEX] + answer_ids)[:max_len]
    # Keep one target per decoder input position; the final target is END.
    decoder_out_ids = (answer_ids + [END_INDEX])[:len(decoder_in_ids) - 1] + [END_INDEX]

    return (
        pad_to_tensor(encoder_ids),
        pad_to_tensor(decoder_in_ids),
        pad_to_tensor(decoder_out_ids),
    )

# 6. Custom Dataset Class (Same as original)
class MathDataset(Dataset):
    """Map-style dataset that tokenizes question/answer examples on access.

    Wraps an indexable collection of examples and converts each one with
    `process_example` when it is requested, so nothing is precomputed.
    """

    def __init__(self, dataset, max_len, word_to_index):
        # Stored as-is; all processing happens lazily in __getitem__.
        self.dataset = dataset
        self.max_len = max_len
        self.word_to_index = word_to_index

    def __len__(self):
        """Number of wrapped examples."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the (src, tgt_input, tgt_output) tensors for example `idx`."""
        return process_example(self.dataset[idx], self.max_len, self.word_to_index)

# 7. Collate Function (Same as original)
def collate_fn(batch):
    """Stack a list of (src, tgt_input, tgt_output) triples into batch tensors.

    Each of the three fields is gathered across the batch and padded with
    PAD_INDEX to the longest sequence of that field (batch dimension first).
    Returns a 3-tuple in the same field order.
    """
    sources, decoder_inputs, decoder_targets = zip(*batch)
    return tuple(
        pad_sequence(field, batch_first=True, padding_value=PAD_INDEX)
        for field in (sources, decoder_inputs, decoder_targets)
    )

# 8. Create DataLoaders
# Both splits share the vocabulary built from the training split; words that
# only occur in the test split are therefore encoded as UNK_INDEX.
print("Creating DataLoaders...")
train_math_dataset = MathDataset(train_dataset, MAX_LEN, word_to_index)
test_math_dataset = MathDataset(test_dataset, MAX_LEN, word_to_index)

# Training batches are reshuffled every epoch; evaluation order stays fixed.
train_dataloader = DataLoader(train_math_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_math_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
print("DataLoaders created.")

# 9. Transformer Model Definition (Same as original - Consider Pre-LN if increasing layers)
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention over batch-first inputs.

    `d_model` is split evenly across `num_heads` heads of width `d_k`.
    Queries, keys and values are projected, attended per head, merged back
    to `d_model` and passed through a final output projection.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # Input projections for queries, keys, values, then the output projection.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Return (weighted values, attention weights) for per-head Q, K, V.

        Positions where `mask == 0` receive a score of -1e9 before the
        softmax, which drives their weight to zero.
        """
        scores = Q @ K.transpose(-2, -1)
        scores = scores / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = torch.softmax(scores, dim=-1)
        return weights @ V, weights

    def split_heads(self, x):
        """(batch, seq, d_model) -> (batch, num_heads, seq, d_k)."""
        batch, length, _ = x.size()
        per_head = x.view(batch, length, self.num_heads, self.d_k)
        return per_head.transpose(1, 2)

    def combine_heads(self, x):
        """(batch, num_heads, seq, d_k) -> (batch, seq, d_model)."""
        batch, _, length, _ = x.size()
        # transpose makes the tensor non-contiguous; view needs contiguous memory.
        merged = x.transpose(1, 2).contiguous()
        return merged.view(batch, length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """Attend from Q over K/V; returns (output, attention weights)."""
        queries = self.split_heads(self.W_q(Q))
        keys = self.split_heads(self.W_k(K))
        values = self.split_heads(self.W_v(V))
        context, weights = self.scaled_dot_product_attention(queries, keys, values, mask)
        return self.W_o(self.combine_heads(context)), weights

class PositionWiseFeedForward(nn.Module):
    """Two-layer feed-forward block: d_model -> d_ff -> d_model with a ReLU between."""

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Expand to d_ff, apply ReLU, project back to d_model."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

class EncoderLayer(nn.Module):
    """One Post-LN encoder layer: self-attention then feed-forward.

    Each sub-layer output goes through dropout, is added to its input
    (residual connection) and the sum is layer-normalized.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = PositionWiseFeedForward(d_model, d_ff)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Apply the layer to `x`; `mask` is forwarded to self-attention."""
        attended, _ = self.mha(x, x, x, mask)
        x = self.norm1(x + self.dropout(attended))
        transformed = self.ffn(x)
        return self.norm2(x + self.dropout(transformed))

class DecoderLayer(nn.Module):
    """One Post-LN decoder layer.

    Sub-layers in order: masked self-attention, encoder-decoder attention,
    feed-forward. Each is followed by dropout, a residual add and LayerNorm.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.masked_mha = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.enc_dec_mha = MultiHeadAttention(d_model, num_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = PositionWiseFeedForward(d_model, d_ff)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        """Apply the layer to decoder states `x` given the encoder output."""
        # Self-attention over the target sequence, restricted by tgt_mask.
        self_attended, _ = self.masked_mha(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(self_attended))

        # Queries come from the decoder; keys and values from the encoder.
        cross_attended, _ = self.enc_dec_mha(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(cross_attended))

        transformed = self.ffn(x)
        return self.norm3(x + self.dropout(transformed))

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to batch-first embeddings.

    A (max_len, d_model) table is precomputed: even columns hold
    sin(pos * freq) and odd columns hold cos(pos * freq). It is registered
    as the buffer `pe`, so it follows the module across devices without
    being a trainable parameter.

    Args:
        d_model: embedding width; both even and odd values are supported.
        max_len: number of positions in the table. Inputs to `forward`
            must not be longer than this.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine column than sine
        # columns, so only the first d_model // 2 frequencies are used here.
        # For an even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)  # shape: (max_len, d_model)

    def forward(self, x):
        """Return `x` plus the encodings of its first `x.size(1)` positions.

        `x` is indexed as (batch, seq_len, d_model); the table slice gets a
        leading batch dimension so it broadcasts over the batch.
        """
        seq_len = x.size(1)
        return x + self.pe[:seq_len, :].unsqueeze(0)

class Encoder(nn.Module):
    """Token embedding + positional encoding followed by a stack of encoder layers."""

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, max_len, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList(
            [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, mask):
        """Encode token ids `src`; `mask` is passed to every layer's self-attention."""
        hidden = self.embedding(src)
        hidden = self.pos_encoding(hidden)
        hidden = self.dropout(hidden)
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)
        return hidden

class Decoder(nn.Module):
    """Embedding, positional encoding, decoder layers and a vocabulary projection.

    The forward pass returns unnormalized scores with `vocab_size` entries
    per target position (no softmax is applied here).
    """

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, max_len, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList(
            [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.fc = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, tgt, enc_output, src_mask, tgt_mask):
        """Decode target ids `tgt` against `enc_output` and return vocabulary scores."""
        hidden = self.embedding(tgt)
        hidden = self.pos_encoding(hidden)
        hidden = self.dropout(hidden)
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, enc_output, src_mask, tgt_mask)
        return self.fc(hidden)

class Transformer(nn.Module):
    """Encoder-decoder Transformer that builds its own attention masks from token ids."""

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_layers, num_heads, d_ff, max_len, dropout):
        super().__init__()
        self.encoder = Encoder(src_vocab_size, d_model, num_layers, num_heads, d_ff, max_len, dropout)
        self.decoder = Decoder(tgt_vocab_size, d_model, num_layers, num_heads, d_ff, max_len, dropout)

    def make_src_mask(self, src):
        """Boolean mask, True where `src` is not padding.

        For (batch, src_len) input the result is (batch, 1, 1, src_len), so it
        broadcasts over heads and query positions and hides padded keys.
        """
        is_token = src != PAD_INDEX
        return is_token.unsqueeze(1).unsqueeze(2)

    def make_tgt_mask(self, tgt):
        """Boolean mask combining the padding mask with a causal mask.

        Entry [b, 0, i, j] is True only when target position j is not padding
        and j <= i, giving shape (batch, 1, tgt_len, tgt_len).
        """
        is_token = (tgt != PAD_INDEX).unsqueeze(1).unsqueeze(2)
        length = tgt.size(1)
        lower_triangle = torch.tril(torch.ones((length, length), device=tgt.device)).bool()
        causal = lower_triangle.unsqueeze(0).unsqueeze(0)
        return is_token & causal

    def forward(self, src, tgt):
        """Encode `src`, decode `tgt` against it and return the decoder scores."""
        src_mask = self.make_src_mask(src)
        tgt_mask = self.make_tgt_mask(tgt)
        memory = self.encoder(src, src_mask)
        return self.decoder(tgt, memory, src_mask, tgt_mask)

print("Transformer model definition complete.")

# 10. Initialize Model, Optimizer, Scheduler, Loss
# Source and target share one vocabulary, so both embedding tables get the same size.
print("Initializing model, optimizer, scheduler...")
model = Transformer(updated_vocab_size, updated_vocab_size, D_MODEL, NUM_LAYERS, NUM_HEADS, D_FF, MAX_LEN, DROPOUT)

# SUGGESTION: Consider an explicit weight initialization scheme (e.g., Xavier):
# for p in model.parameters():
#     if p.dim() > 1:
#         nn.init.xavier_uniform_(p)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01) # Weight decay set explicitly
# Upper bound on optimizer steps: if early stopping ends training sooner,
# the linear decay never reaches zero.
total_steps = len(train_dataloader) * NUM_EPOCHS
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=WARMUP_STEPS, num_training_steps=total_steps)
# Targets equal to PAD_INDEX contribute nothing to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_INDEX) # SUGGESTION: Consider label_smoothing=0.1
print("Initialization complete.")
print(f"Using device: {device}")

# 11. Training Loop with Early Stopping
# Every epoch runs one pass over train_dataloader (with parameter updates) and
# one pass over test_dataloader (no gradients). Training stops once the average
# evaluation loss has failed to improve for PATIENCE consecutive epochs.
# NOTE(review): early stopping is driven by the test split, so the reported
# evaluation loss is optimistically biased; consider a held-out validation split.
print("Starting training loop...")
best_eval_loss = float('inf')
epochs_no_improve = 0

for epoch in range(NUM_EPOCHS):
    print(f"\n--- Epoch {epoch+1}/{NUM_EPOCHS} ---")
    # --- Training Phase ---
    model.train()  # enables dropout
    total_train_loss = 0
    train_progress_bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f"Epoch {epoch+1} Train")

    for batch_idx, (src, tgt_in, tgt_out) in train_progress_bar:
        src = src.to(device)
        tgt_in = tgt_in.to(device)
        tgt_out = tgt_out.to(device) # Shape: [batch_size, max_len]

        optimizer.zero_grad()

        # Teacher forcing: the decoder receives the START-prefixed answer (tgt_in)
        # and is trained to predict the one-step-shifted answer (tgt_out).
        output = model(src, tgt_in) # Output shape: [batch_size, max_len, vocab_size]

        # Flatten batch and time dimensions: the loss expects [N, C] scores and [N] targets.
        output_flat = output.view(-1, output.size(-1)) # Shape: [batch_size * max_len, vocab_size]
        tgt_out_flat = tgt_out.view(-1)              # Shape: [batch_size * max_len]

        loss = criterion(output_flat, tgt_out_flat)

        loss.backward()
        # Clip the global gradient norm before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIPPING)
        optimizer.step()
        scheduler.step() # The schedule advances once per batch, not once per epoch

        total_train_loss += loss.item()
        train_progress_bar.set_postfix({"loss": loss.item()})

    # Mean of the per-batch losses.
    avg_train_loss = total_train_loss / len(train_dataloader)
    print(f"Epoch {epoch+1} Average Training Loss: {avg_train_loss:.4f}")

    # --- Evaluation Phase ---
    model.eval()  # disables dropout
    total_eval_loss = 0
    eval_progress_bar = tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc=f"Epoch {epoch+1} Eval")

    with torch.no_grad():
        for batch_idx, (src, tgt_in, tgt_out) in eval_progress_bar:
            src = src.to(device)
            tgt_in = tgt_in.to(device)
            tgt_out = tgt_out.to(device)

            # Same teacher-forced forward pass as in training.
            output = model(src, tgt_in)

            output_flat = output.view(-1, output.size(-1))
            tgt_out_flat = tgt_out.view(-1)

            loss = criterion(output_flat, tgt_out_flat)
            total_eval_loss += loss.item()
            eval_progress_bar.set_postfix({"loss": loss.item()})

    avg_eval_loss = total_eval_loss / len(test_dataloader)
    print(f"Epoch {epoch+1} Average Evaluation Loss: {avg_eval_loss:.4f}")

    # --- Early Stopping Logic ---
    if avg_eval_loss < best_eval_loss:
        best_eval_loss = avg_eval_loss
        epochs_no_improve = 0
        # Optional: save the best model checkpoint. Without this, the weights
        # used after training are those of the last epoch, not the best one.
        # torch.save(model.state_dict(), "best_scratch_model.pth")
        print(f"Evaluation loss improved to {best_eval_loss:.4f}. Resetting patience.")
    else:
        epochs_no_improve += 1
        print(f"Evaluation loss did not improve for {epochs_no_improve} epoch(s).")

    if epochs_no_improve >= PATIENCE:
        print(f"Early stopping triggered after {epoch+1} epochs.")
        break # Exit the main training loop

print("\nTraining finished.")
# Optional: load the best model if it was saved (requires `import os`)
# if os.path.exists("best_scratch_model.pth"):
#     model.load_state_dict(torch.load("best_scratch_model.pth"))
#     print("Loaded best model checkpoint for inference.")


# 12. Inference Function (Adapted from original - greedy decoding)
def translate_sentence(model_inf, sentence, word_to_index_inf, index_to_word_inf, max_len_inf, device_inf):
    """Greedily decode an answer for `sentence` with a trained model.

    Args:
        model_inf: model exposing `encoder`, `decoder`, `make_src_mask` and
            `make_tgt_mask`; it is switched to eval mode.
        sentence: input text; split on whitespace and lowercased per word.
        word_to_index_inf: token -> id mapping; unknown words map to UNK_INDEX.
        index_to_word_inf: id -> token mapping used to render the output.
        max_len_inf: length the source is truncated/padded to; at most
            `max_len_inf - 1` tokens are generated.
        device_inf: device on which the input tensors are created.

    Returns:
        The generated tokens joined by single spaces, without the START and
        END markers. Generation stops early once END is predicted.
    """
    model_inf.eval()

    # Encode the source like the training data: START + tokens + END,
    # truncated and then padded to max_len_inf.
    tokens = [word_to_index_inf.get(word.lower(), UNK_INDEX) for word in sentence.split()]
    src_tokens = ([START_INDEX] + tokens + [END_INDEX])[:max_len_inf]
    src_tokens = src_tokens + [PAD_INDEX] * (max_len_inf - len(src_tokens))

    if max_len_inf <= 1:
        # No decoding step would run, so there is nothing to generate.
        return ""

    src_tensor = torch.tensor(src_tokens, device=device_inf).unsqueeze(0)  # add batch dim

    tgt_tokens = [START_INDEX]
    with torch.no_grad():
        # The source never changes during decoding, so its mask and the
        # encoder output are computed once, outside the generation loop.
        src_mask = model_inf.make_src_mask(src_tensor)
        enc_output = model_inf.encoder(src_tensor, src_mask)

        for _ in range(max_len_inf - 1):
            tgt_tensor = torch.tensor(tgt_tokens, device=device_inf).unsqueeze(0)
            tgt_mask = model_inf.make_tgt_mask(tgt_tensor)

            # output is indexed as [batch, position, vocabulary]
            output = model_inf.decoder(tgt_tensor, enc_output, src_mask, tgt_mask)

            # Greedy choice: highest-scoring token at the last position.
            pred_token = output[0, -1].argmax().item()
            tgt_tokens.append(pred_token)

            if pred_token == END_INDEX:
                break

    # Render ids as words, dropping the START/END markers.
    translated_words = [
        index_to_word_inf.get(token, UNK_TOKEN)
        for token in tgt_tokens
        if token != START_INDEX and token != END_INDEX
    ]
    return " ".join(translated_words)


# 13. Example Inference (Using the trained scratch model)
# The model used here holds the weights from the last epoch that ran.
print("\nRunning example inference with the trained scratch model...")

# Question and reference answer of the test example at index 10
sample_question = test_dataset[10]['question']
actual_answer = test_dataset[10]['answer']

# Greedy decoding with the inference function defined above
predicted_answer = translate_sentence(model, sample_question, word_to_index, index_to_word, MAX_LEN, device)

print("\n--- Example Inference (From-Scratch Model) ---")
print(f"Question: {sample_question}")
print(f"Actual Answer:\n{actual_answer}")
print("-" * 20)
print(f"Predicted Answer:\n{predicted_answer}") # Likely still poor without major changes
print("-" * 20)

print("Code execution structure complete.")

Libraries imported.
Hyperparameters set.
Loading GSM8k dataset...
Dataset loaded.
Building custom vocabulary (NOTE: This is a potential bottleneck)...
Custom vocabulary built. Size: 49241
Creating DataLoaders...
DataLoaders created.
Transformer model definition complete.
Initializing model, optimizer, scheduler...
Initialization complete.
Using device: cuda
Starting training loop...

--- Epoch 1/150 ---


Epoch 1 Train: 100%|██████████| 234/234 [00:24<00:00,  9.63it/s, loss=10.7]


Epoch 1 Average Training Loss: 10.8780


Epoch 1 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.88it/s, loss=10.8]


Epoch 1 Average Evaluation Loss: 10.6785
Evaluation loss improved to 10.6785. Resetting patience.

--- Epoch 2/150 ---


Epoch 2 Train: 100%|██████████| 234/234 [00:23<00:00,  9.77it/s, loss=9.97]


Epoch 2 Average Training Loss: 10.3369


Epoch 2 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.74it/s, loss=10.2]


Epoch 2 Average Evaluation Loss: 10.0151
Evaluation loss improved to 10.0151. Resetting patience.

--- Epoch 3/150 ---


Epoch 3 Train: 100%|██████████| 234/234 [00:24<00:00,  9.74it/s, loss=9.62]


Epoch 3 Average Training Loss: 9.7518


Epoch 3 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.56it/s, loss=9.73]


Epoch 3 Average Evaluation Loss: 9.5507
Evaluation loss improved to 9.5507. Resetting patience.

--- Epoch 4/150 ---


Epoch 4 Train: 100%|██████████| 234/234 [00:24<00:00,  9.70it/s, loss=9.01]


Epoch 4 Average Training Loss: 9.2849


Epoch 4 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.25it/s, loss=9.27]


Epoch 4 Average Evaluation Loss: 9.0808
Evaluation loss improved to 9.0808. Resetting patience.

--- Epoch 5/150 ---


Epoch 5 Train: 100%|██████████| 234/234 [00:24<00:00,  9.69it/s, loss=8.56]


Epoch 5 Average Training Loss: 8.7642


Epoch 5 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.52it/s, loss=8.76]


Epoch 5 Average Evaluation Loss: 8.5656
Evaluation loss improved to 8.5656. Resetting patience.

--- Epoch 6/150 ---


Epoch 6 Train: 100%|██████████| 234/234 [00:24<00:00,  9.66it/s, loss=7.92]


Epoch 6 Average Training Loss: 8.2469


Epoch 6 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.42it/s, loss=8.28]


Epoch 6 Average Evaluation Loss: 8.1065
Evaluation loss improved to 8.1065. Resetting patience.

--- Epoch 7/150 ---


Epoch 7 Train: 100%|██████████| 234/234 [00:24<00:00,  9.64it/s, loss=7.56]


Epoch 7 Average Training Loss: 7.8077


Epoch 7 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.40it/s, loss=7.9]


Epoch 7 Average Evaluation Loss: 7.7385
Evaluation loss improved to 7.7385. Resetting patience.

--- Epoch 8/150 ---


Epoch 8 Train: 100%|██████████| 234/234 [00:24<00:00,  9.63it/s, loss=7.33]


Epoch 8 Average Training Loss: 7.4784


Epoch 8 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.30it/s, loss=7.63]


Epoch 8 Average Evaluation Loss: 7.5010
Evaluation loss improved to 7.5010. Resetting patience.

--- Epoch 9/150 ---


Epoch 9 Train: 100%|██████████| 234/234 [00:24<00:00,  9.63it/s, loss=7.16]


Epoch 9 Average Training Loss: 7.2709


Epoch 9 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.28it/s, loss=7.46]


Epoch 9 Average Evaluation Loss: 7.3523
Evaluation loss improved to 7.3523. Resetting patience.

--- Epoch 10/150 ---


Epoch 10 Train: 100%|██████████| 234/234 [00:24<00:00,  9.56it/s, loss=7.14]


Epoch 10 Average Training Loss: 7.1280


Epoch 10 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.01it/s, loss=7.32]


Epoch 10 Average Evaluation Loss: 7.2376
Evaluation loss improved to 7.2376. Resetting patience.

--- Epoch 11/150 ---


Epoch 11 Train: 100%|██████████| 234/234 [00:24<00:00,  9.57it/s, loss=7.09]


Epoch 11 Average Training Loss: 7.0091


Epoch 11 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.15it/s, loss=7.22]


Epoch 11 Average Evaluation Loss: 7.1345
Evaluation loss improved to 7.1345. Resetting patience.

--- Epoch 12/150 ---


Epoch 12 Train: 100%|██████████| 234/234 [00:24<00:00,  9.58it/s, loss=6.71]


Epoch 12 Average Training Loss: 6.9038


Epoch 12 Eval: 100%|██████████| 42/42 [00:01<00:00, 28.75it/s, loss=7.12]


Epoch 12 Average Evaluation Loss: 7.0406
Evaluation loss improved to 7.0406. Resetting patience.

--- Epoch 13/150 ---


Epoch 13 Train: 100%|██████████| 234/234 [00:24<00:00,  9.57it/s, loss=6.74]


Epoch 13 Average Training Loss: 6.8156


Epoch 13 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.22it/s, loss=7.02]


Epoch 13 Average Evaluation Loss: 6.9625
Evaluation loss improved to 6.9625. Resetting patience.

--- Epoch 14/150 ---


Epoch 14 Train: 100%|██████████| 234/234 [00:24<00:00,  9.57it/s, loss=6.87]


Epoch 14 Average Training Loss: 6.7219


Epoch 14 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.18it/s, loss=6.9]


Epoch 14 Average Evaluation Loss: 6.8643
Evaluation loss improved to 6.8643. Resetting patience.

--- Epoch 15/150 ---


Epoch 15 Train: 100%|██████████| 234/234 [00:24<00:00,  9.55it/s, loss=6.45]


Epoch 15 Average Training Loss: 6.6354


Epoch 15 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.04it/s, loss=6.84]


Epoch 15 Average Evaluation Loss: 6.8078
Evaluation loss improved to 6.8078. Resetting patience.

--- Epoch 16/150 ---


Epoch 16 Train: 100%|██████████| 234/234 [00:24<00:00,  9.55it/s, loss=6.38]


Epoch 16 Average Training Loss: 6.5766


Epoch 16 Eval: 100%|██████████| 42/42 [00:01<00:00, 28.84it/s, loss=6.8]


Epoch 16 Average Evaluation Loss: 6.7585
Evaluation loss improved to 6.7585. Resetting patience.

--- Epoch 17/150 ---


Epoch 17 Train: 100%|██████████| 234/234 [00:24<00:00,  9.55it/s, loss=6.42]


Epoch 17 Average Training Loss: 6.5203


Epoch 17 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.07it/s, loss=6.75]


Epoch 17 Average Evaluation Loss: 6.7139
Evaluation loss improved to 6.7139. Resetting patience.

--- Epoch 18/150 ---


Epoch 18 Train: 100%|██████████| 234/234 [00:24<00:00,  9.54it/s, loss=6.26]


Epoch 18 Average Training Loss: 6.4712


Epoch 18 Eval: 100%|██████████| 42/42 [00:01<00:00, 28.87it/s, loss=6.71]


Epoch 18 Average Evaluation Loss: 6.6752
Evaluation loss improved to 6.6752. Resetting patience.

--- Epoch 19/150 ---


Epoch 19 Train: 100%|██████████| 234/234 [00:24<00:00,  9.54it/s, loss=6.51]


Epoch 19 Average Training Loss: 6.4249


Epoch 19 Eval: 100%|██████████| 42/42 [00:01<00:00, 28.97it/s, loss=6.66]


Epoch 19 Average Evaluation Loss: 6.6323
Evaluation loss improved to 6.6323. Resetting patience.

--- Epoch 20/150 ---


Epoch 20 Train: 100%|██████████| 234/234 [00:24<00:00,  9.55it/s, loss=6.43]


Epoch 20 Average Training Loss: 6.3778


Epoch 20 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.18it/s, loss=6.62]


Epoch 20 Average Evaluation Loss: 6.5985
Evaluation loss improved to 6.5985. Resetting patience.

--- Epoch 21/150 ---


Epoch 21 Train: 100%|██████████| 234/234 [00:24<00:00,  9.54it/s, loss=6.31]


Epoch 21 Average Training Loss: 6.3337


Epoch 21 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.10it/s, loss=6.59]


Epoch 21 Average Evaluation Loss: 6.5666
Evaluation loss improved to 6.5666. Resetting patience.

--- Epoch 22/150 ---


Epoch 22 Train: 100%|██████████| 234/234 [00:24<00:00,  9.54it/s, loss=6.19]


Epoch 22 Average Training Loss: 6.2943


Epoch 22 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.11it/s, loss=6.55]


Epoch 22 Average Evaluation Loss: 6.5301
Evaluation loss improved to 6.5301. Resetting patience.

--- Epoch 23/150 ---


Epoch 23 Train: 100%|██████████| 234/234 [00:24<00:00,  9.55it/s, loss=6.61]


Epoch 23 Average Training Loss: 6.2560


Epoch 23 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.05it/s, loss=6.53]


Epoch 23 Average Evaluation Loss: 6.5080
Evaluation loss improved to 6.5080. Resetting patience.

--- Epoch 24/150 ---


Epoch 24 Train: 100%|██████████| 234/234 [00:24<00:00,  9.54it/s, loss=6.25]


Epoch 24 Average Training Loss: 6.2168


Epoch 24 Eval: 100%|██████████| 42/42 [00:01<00:00, 28.87it/s, loss=6.5]


Epoch 24 Average Evaluation Loss: 6.4774
Evaluation loss improved to 6.4774. Resetting patience.

--- Epoch 25/150 ---


Epoch 25 Train: 100%|██████████| 234/234 [00:24<00:00,  9.53it/s, loss=6.31]


Epoch 25 Average Training Loss: 6.1828


Epoch 25 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.00it/s, loss=6.46]


Epoch 25 Average Evaluation Loss: 6.4549
Evaluation loss improved to 6.4549. Resetting patience.

--- Epoch 26/150 ---


Epoch 26 Train: 100%|██████████| 234/234 [00:24<00:00,  9.53it/s, loss=6.3]


Epoch 26 Average Training Loss: 6.1494


Epoch 26 Eval: 100%|██████████| 42/42 [00:01<00:00, 28.84it/s, loss=6.43]


Epoch 26 Average Evaluation Loss: 6.4266
Evaluation loss improved to 6.4266. Resetting patience.

--- Epoch 27/150 ---


Epoch 27 Train: 100%|██████████| 234/234 [00:24<00:00,  9.53it/s, loss=6.25]


Epoch 27 Average Training Loss: 6.1162


Epoch 27 Eval: 100%|██████████| 42/42 [00:01<00:00, 28.86it/s, loss=6.41]


Epoch 27 Average Evaluation Loss: 6.4050
Evaluation loss improved to 6.4050. Resetting patience.

--- Epoch 28/150 ---


Epoch 28 Train: 100%|██████████| 234/234 [00:24<00:00,  9.53it/s, loss=6.29]


Epoch 28 Average Training Loss: 6.0865


Epoch 28 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.03it/s, loss=6.39]


Epoch 28 Average Evaluation Loss: 6.3835
Evaluation loss improved to 6.3835. Resetting patience.

--- Epoch 29/150 ---


Epoch 29 Train: 100%|██████████| 234/234 [00:24<00:00,  9.52it/s, loss=6.1]


Epoch 29 Average Training Loss: 6.0539


Epoch 29 Eval: 100%|██████████| 42/42 [00:01<00:00, 28.83it/s, loss=6.36]


Epoch 29 Average Evaluation Loss: 6.3609
Evaluation loss improved to 6.3609. Resetting patience.

--- Epoch 30/150 ---


Epoch 30 Train: 100%|██████████| 234/234 [00:24<00:00,  9.53it/s, loss=6.39]


Epoch 30 Average Training Loss: 6.0269


Epoch 30 Eval: 100%|██████████| 42/42 [00:01<00:00, 28.93it/s, loss=6.35]


Epoch 30 Average Evaluation Loss: 6.3422
Evaluation loss improved to 6.3422. Resetting patience.

--- Epoch 31/150 ---


Epoch 31 Train: 100%|██████████| 234/234 [00:24<00:00,  9.52it/s, loss=5.74]


Epoch 31 Average Training Loss: 5.9965


Epoch 31 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.12it/s, loss=6.33]


Epoch 31 Average Evaluation Loss: 6.3266
Evaluation loss improved to 6.3266. Resetting patience.

--- Epoch 32/150 ---


Epoch 32 Train: 100%|██████████| 234/234 [00:24<00:00,  9.53it/s, loss=6.26]


Epoch 32 Average Training Loss: 5.9708


Epoch 32 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.08it/s, loss=6.3]


Epoch 32 Average Evaluation Loss: 6.3071
Evaluation loss improved to 6.3071. Resetting patience.

--- Epoch 33/150 ---


Epoch 33 Train: 100%|██████████| 234/234 [00:24<00:00,  9.52it/s, loss=5.91]


Epoch 33 Average Training Loss: 5.9451


Epoch 33 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.03it/s, loss=6.29]


Epoch 33 Average Evaluation Loss: 6.2940
Evaluation loss improved to 6.2940. Resetting patience.

--- Epoch 34/150 ---


Epoch 34 Train: 100%|██████████| 234/234 [00:24<00:00,  9.51it/s, loss=5.6]


Epoch 34 Average Training Loss: 5.9178


Epoch 34 Eval: 100%|██████████| 42/42 [00:01<00:00, 28.81it/s, loss=6.28]


Epoch 34 Average Evaluation Loss: 6.2740
Evaluation loss improved to 6.2740. Resetting patience.

--- Epoch 35/150 ---


Epoch 35 Train: 100%|██████████| 234/234 [00:24<00:00,  9.51it/s, loss=5.71]


Epoch 35 Average Training Loss: 5.8930


Epoch 35 Eval: 100%|██████████| 42/42 [00:01<00:00, 28.80it/s, loss=6.27]


Epoch 35 Average Evaluation Loss: 6.2633
Evaluation loss improved to 6.2633. Resetting patience.

--- Epoch 36/150 ---


Epoch 36 Train: 100%|██████████| 234/234 [00:24<00:00,  9.51it/s, loss=6.03]


Epoch 36 Average Training Loss: 5.8695


Epoch 36 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.08it/s, loss=6.24]


Epoch 36 Average Evaluation Loss: 6.2452
Evaluation loss improved to 6.2452. Resetting patience.

--- Epoch 37/150 ---


Epoch 37 Train: 100%|██████████| 234/234 [00:24<00:00,  9.53it/s, loss=5.71]


Epoch 37 Average Training Loss: 5.8455


Epoch 37 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.03it/s, loss=6.23]


Epoch 37 Average Evaluation Loss: 6.2315
Evaluation loss improved to 6.2315. Resetting patience.

--- Epoch 38/150 ---


Epoch 38 Train: 100%|██████████| 234/234 [00:24<00:00,  9.52it/s, loss=5.82]


Epoch 38 Average Training Loss: 5.8242


Epoch 38 Eval: 100%|██████████| 42/42 [00:01<00:00, 28.76it/s, loss=6.23]


Epoch 38 Average Evaluation Loss: 6.2200
Evaluation loss improved to 6.2200. Resetting patience.

--- Epoch 39/150 ---


Epoch 39 Train: 100%|██████████| 234/234 [00:24<00:00,  9.53it/s, loss=5.83]


Epoch 39 Average Training Loss: 5.8013


Epoch 39 Eval: 100%|██████████| 42/42 [00:01<00:00, 28.96it/s, loss=6.2]


Epoch 39 Average Evaluation Loss: 6.2065
Evaluation loss improved to 6.2065. Resetting patience.

--- Epoch 40/150 ---


Epoch 40 Train: 100%|██████████| 234/234 [00:24<00:00,  9.53it/s, loss=5.89]


Epoch 40 Average Training Loss: 5.7799


Epoch 40 Eval: 100%|██████████| 42/42 [00:01<00:00, 28.81it/s, loss=6.19]


Epoch 40 Average Evaluation Loss: 6.1939
Evaluation loss improved to 6.1939. Resetting patience.

--- Epoch 41/150 ---


Epoch 41 Train: 100%|██████████| 234/234 [00:24<00:00,  9.52it/s, loss=5.66]


Epoch 41 Average Training Loss: 5.7600


Epoch 41 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.03it/s, loss=6.17]


Epoch 41 Average Evaluation Loss: 6.1792
Evaluation loss improved to 6.1792. Resetting patience.

--- Epoch 42/150 ---


Epoch 42 Train: 100%|██████████| 234/234 [00:24<00:00,  9.53it/s, loss=5.79]


Epoch 42 Average Training Loss: 5.7375


Epoch 42 Eval: 100%|██████████| 42/42 [00:01<00:00, 28.88it/s, loss=6.16]


Epoch 42 Average Evaluation Loss: 6.1719
Evaluation loss improved to 6.1719. Resetting patience.

--- Epoch 43/150 ---


Epoch 43 Train: 100%|██████████| 234/234 [00:24<00:00,  9.53it/s, loss=5.67]


Epoch 43 Average Training Loss: 5.7207


Epoch 43 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.12it/s, loss=6.15]


Epoch 43 Average Evaluation Loss: 6.1564
Evaluation loss improved to 6.1564. Resetting patience.

--- Epoch 44/150 ---


Epoch 44 Train: 100%|██████████| 234/234 [00:24<00:00,  9.55it/s, loss=5.58]


Epoch 44 Average Training Loss: 5.6994


Epoch 44 Eval: 100%|██████████| 42/42 [00:01<00:00, 29.02it/s, loss=6.15]


Epoch 44 Average Evaluation Loss: 6.1526
Evaluation loss improved to 6.1526. Resetting patience.

--- Epoch 45/150 ---


Epoch 45 Train: 100%|██████████| 234/234 [00:24<00:00,  9.53it/s, loss=5.61]


Epoch 45 Average Training Loss: 5.6809


Epoch 45 Eval: 100%|██████████| 42/42 [00:01<00:00, 28.92it/s, loss=6.14]


Epoch 45 Average Evaluation Loss: 6.1414
Evaluation loss improved to 6.1414. Resetting patience.

--- Epoch 46/150 ---


Epoch 46 Train: 100%|██████████| 234/234 [00:24<00:00,  9.53it/s, loss=5.61]


Epoch 46 Average Training Loss: 5.6621


Epoch 46 Eval: 100%|██████████| 42/42 [00:01<00:00, 28.84it/s, loss=6.12]


Epoch 46 Average Evaluation Loss: 6.1306
Evaluation loss improved to 6.1306. Resetting patience.

--- Epoch 47/150 ---


Epoch 47 Train: 100%|██████████| 234/234 [00:24<00:00,  9.54it/s, loss=5.68]


Epoch 47 Average Training Loss: 5.6443


Epoch 47 Eval: 100%|██████████| 42/42 [00:01<00:00, 28.90it/s, loss=6.11]


Epoch 47 Average Evaluation Loss: 6.1194
Evaluation loss improved to 6.1194. Resetting patience.

--- Epoch 48/150 ---


Epoch 48 Train: 100%|██████████| 234/234 [00:24<00:00,  9.53it/s, loss=5.53]


Epoch 48 Average Training Loss: 5.6276


Epoch 48 Eval: 100%|██████████| 42/42 [00:01<00:00, 28.84it/s, loss=6.11]


Epoch 48 Average Evaluation Loss: 6.1146
Evaluation loss improved to 6.1146. Resetting patience.

--- Epoch 49/150 ---


Epoch 49 Train: 100%|██████████| 234/234 [00:24<00:00,  9.52it/s, loss=5.54]


Epoch 49 Average Training Loss: 5.6103


Epoch 49 Eval: 100%|██████████| 42/42 [00:01<00:00, 28.75it/s, loss=6.1]


Epoch 49 Average Evaluation Loss: 6.1044
Evaluation loss improved to 6.1044. Resetting patience.

--- Epoch 50/150 ---


Epoch 50 Train:  48%|████▊     | 113/234 [00:11<00:12,  9.46it/s, loss=5.85]