# Setup

## Imports

In [None]:
# Import importlib to reload modules and sys and os to add the path for other imports
import importlib
import sys
import os
import torch

# Append the parent directory to the path to import the necessary modules
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

# Import the utilities and the dataloader
from utils import trainutil, inferutil, setuputil

# Now reload the modules to ensure they are up-to-date
importlib.reload(setuputil)
importlib.reload(trainutil)
importlib.reload(inferutil)

# Import the funcs needed from utils
from utils.setuputil import setup_config, display_config
from utils.trainutil import train_model
from utils.inferutil import infer_one, infer_full


## Configuration

In [None]:
# Experiment configuration: every knob a reader may want to tune lives here.
# setup_config() turns this into the full runtime config (loaders, device, ...).
input_config = dict(
    # Environment and model info
    env="gcp",
    approach="bert",
    model_name="TestBert",
    model_base="prajjwal1/bert-tiny",
    # System configuration
    device="cuda:0",
    threads=14,
    seed=42,
    # Data configuration
    data_dir="../../data/farzan",
    data_ds="manual",
    # Model parameters (presumably grid rows, grid cols, tokens per cell)
    rows=100,
    cols=100,
    tokens=32,
    # Training parameters
    batch=2,
    lr=1e-4,
    mu=0.25,
    epochs=3,
    patience=2,
    save_int=0,
    save_dir='../models/',
)

config = setup_config(input_config)
display_config(config)

In [None]:
# Define local variables from the processed config dictionary
DEVICE = config["DEVICE"]
THREADS = config["THREADS"]
# NOTE(review): a bare `gber42_TinyBert_manual_100x100x32` expression stood here and
# raised NameError on a fresh Run All; it looks like a pasted save_name — confirm.

# Data sources and tokenizer
# NOTE(review): train_bert wraps these in torch DataLoaders itself, so despite the
# names they are presumably Datasets — confirm against setup_config.
train_loader = config["train_loader"]
val_loader = config["val_loader"]
test_loader = config["test_loader"]
tokenizer = config["tokenizer"]
model_base = config['model_base']

# Training parameters
batch_size = config["batch"]
lr = config["lr"]
mu = config["mu"]
epochs = config["epochs"]
patience = config["patience"]
save_int = config["save_int"]
save_dir = config["save_dir"]
save_name = config["save_name"]

## Sanity Check: Inspect Loader Contents

In [None]:
# # Check
# # Retrieve the first item's data from train_loader
# first_item = train_loader[0]

# # Extract tensors and file path
# x_tok = first_item['x_tok']
# x_masks = first_item['x_masks']
# y_tok = first_item['y_tok']
# file_path = first_item['file_paths']

# # Print the file path first
# print("File path:", file_path)

# # Print the shapes of the tensors
# print("Shape of x_tok:", x_tok.shape)
# print("Shape of x_masks:", x_masks.shape)
# print("Shape of y_tok:", y_tok.shape)

# # Define cell location
# row = 3
# col = 5

# # Extract data for the specific cell at (row, col)
# xtok_cell = x_tok[row, col, :]  # Tokenized input IDs for the cell
# xmask_cell = x_masks[row, col, :] if x_masks.numel() > 0 else None  # Attention mask for the cell (if applicable)
# ytok_cell = y_tok[row, col, :]  # Metadata tensor for the cell


# # Print extracted cell data
# print(f"\nx_tok at cell ({row}, {col}):\n", xtok_cell.tolist())
# if xmask_cell is not None:
#     print(f"\nx_masks at cell ({row}, {col}):\n", xmask_cell.tolist())
# print(f"\ny_tok at cell ({row}, {col}):\n", ytok_cell.tolist())

# # Decode x_tok of the cell into a list of words using the tokenizer
# decoded_words = tokenizer.decode(xtok_cell.tolist(), skip_special_tokens=False).split()
# print(f"\nDecoded words at cell ({row}, {col}):\n", decoded_words)


# Model Creation

In [None]:
# # Imports
# import torch
# import torch.nn as nn
# from transformers import AutoModel
# from tqdm import tqdm

# # Test model using tinybert for us
# class BertTiny(nn.Module):
#     def __init__(self, model_base="prajjwal1/bert-tiny", dropout_rate=0.05):
#         super(BertTiny, self).__init__()

#         # 1. Load pretrained BERT
#         self.bert = AutoModel.from_pretrained(model_base)

#         # 2. Define a dropout
#         self.dropout = nn.Dropout(dropout_rate)

#         # 3. Non-linear activation (GELU)
#         self.gelu = nn.GELU()

#         # 4. Final predictor (1-dim output per cell)
#         self.classifier = nn.Linear(self.bert.config.hidden_size, 1)

#     def forward(self, input_ids, attention_mask):

#         # 1. Print the overall shapes
#         # print("batch_size:", input_ids.shape[0])
#         # print("rows:",      input_ids.shape[1])
#         # print("cols:",      input_ids.shape[2])
#         # print("tokens:",    input_ids.shape[3])

#         # 2. Initialize S_cube => (batch_size, rows, cols)
#         S_cube = torch.zeros(
#             (input_ids.shape[0], input_ids.shape[1], input_ids.shape[2]),
#             device=input_ids.device
#         )

#         # 3. Loop over all cells
#         for cell in tqdm(range(input_ids.shape[1] * input_ids.shape[2]), desc = 'Forward'):

#             r = cell // input_ids.shape[2]
#             c = cell %  input_ids.shape[2]

#             # Extract the slice for current cell (batch_size x tokens)
#             cell_input_ids  = input_ids[:, r, c, :]
#             cell_attn_mask  = attention_mask[:, r, c, :]

#             # Pass them through the BERT model
#             outputs = self.bert(cell_input_ids, attention_mask=cell_attn_mask)

#             # pooler_out => (batch_size, hidden_dim)
#             pooler_out = outputs.pooler_output

#             # Inlined pipeline: dropout -> GELU -> classifier => (batch_size, 1)
#             logits = self.classifier(self.gelu(self.dropout(pooler_out)))

#             # Flatten (batch_size, 1) => (batch_size,)
#             logits_flat = logits.view(-1)

#             # Populate S_cube => shape: (batch_size, rows, cols)
#             S_cube[:, r, c] = logits_flat

#             # If this is the first cell, do some prints and break
#             # if r == 0 and c == 0:
#             #     print(f"\nFirst cell => row={r}, col={c}")
#             #     print(f"cell_input_ids.shape: {cell_input_ids.shape}")
#             #     print(f"cell_attn_mask.shape: {cell_attn_mask.shape}")
#             #     print(f"logits.shape: {logits.shape}")
#             #     print(f"logits_flat.shape: {logits_flat.shape}")
#             #     print(f"S_cube[:, {r}, {c}].shape: {S_cube[:, r, c].shape}")

#                 #break  # Stop after the first cell

#         # 4. Print the shape of S_cube
#         # print(f"\nS_cube.shape: {S_cube.shape}")

#         # Return S_cube or None, depending on your use case
#         return S_cube

    


In [None]:
import torch
import torch.nn as nn
from transformers import AutoModel
from tqdm import tqdm


class BertTiny(nn.Module):
    """Score every cell of a tokenized grid with a pretrained transformer encoder.

    ``forward`` takes token ids of shape ``(batch, rows, cols, tokens)`` and
    encodes each cell's token sequence independently. The encoder's
    ``pooler_output`` is passed through dropout -> GELU -> a linear layer to
    produce one logit per cell.

    Args:
        model_base: Name or path passed to ``AutoModel.from_pretrained``.
            NOTE(review): the default is full-size ``bert-base-cased`` despite
            the class name; this notebook always passes ``config['model_base']``.
        dropout_rate: Dropout probability applied to the pooled output.
        cells_per_chunk: Maximum number of cell sequences sent to the encoder
            in one call (bounds peak memory). ``None`` encodes every cell of
            the batch in a single call.

    Raises:
        ValueError: If ``cells_per_chunk`` is given and is not positive.
    """

    def __init__(self, model_base="bert-base-cased", dropout_rate=0.05, cells_per_chunk=1024):
        super(BertTiny, self).__init__()

        if cells_per_chunk is not None and cells_per_chunk < 1:
            raise ValueError("cells_per_chunk must be a positive int or None")

        # 1. Load pretrained encoder
        self.bert = AutoModel.from_pretrained(model_base)

        # 2. Dropout on the pooled representation
        self.dropout = nn.Dropout(dropout_rate)

        # 3. Non-linear activation (GELU)
        self.gelu = nn.GELU()

        # 4. Final predictor (1-dim output per cell)
        self.classifier = nn.Linear(self.bert.config.hidden_size, 1)

        self.cells_per_chunk = cells_per_chunk

    def forward(self, input_ids, attention_mask):
        """Return per-cell logits of shape ``(batch, rows, cols)``.

        Args:
            input_ids: Token ids, shape ``(batch, rows, cols, tokens)``.
            attention_mask: Mask with the same shape as ``input_ids``.

        Returns:
            Tensor of the default float dtype on the classifier's device.
        """
        batch_size, n_rows, n_cols, n_tokens = input_ids.shape
        out_dtype = torch.get_default_dtype()
        n_cells = batch_size * n_rows * n_cols
        if n_cells == 0:
            return torch.zeros(
                (batch_size, n_rows, n_cols), dtype=out_dtype, device=input_ids.device
            )

        # Treat each cell as an independent sequence: (B, R, C, T) -> (B*R*C, T).
        # Encoding many cells per call replaces the former Python loop that
        # issued one encoder call per cell (rows * cols calls per batch).
        flat_ids = input_ids.reshape(n_cells, n_tokens)
        flat_mask = attention_mask.reshape(n_cells, n_tokens)

        chunk = self.cells_per_chunk if self.cells_per_chunk else n_cells
        cell_logits = []
        for start in range(0, n_cells, chunk):
            stop = start + chunk
            pooled = self.bert(
                flat_ids[start:stop], attention_mask=flat_mask[start:stop]
            ).pooler_output
            # dropout -> GELU -> classifier: (n, 1) -> (n,)
            cell_logits.append(self.classifier(self.gelu(self.dropout(pooled))).view(-1))

        # The flattening above is row-major, so a plain view restores (B, R, C).
        # Cast to the default dtype to match the previous torch.zeros-backed output.
        return torch.cat(cell_logits).view(batch_size, n_rows, n_cols).to(out_dtype)

In [None]:
# 1) Create model and move to GPU Observe its architecture
# Build the model from the configured base, move it to DEVICE, and show its architecture
untrained_model = BertTiny(model_base=model_base)
untrained_model = untrained_model.to(DEVICE)
print(untrained_model)

In [None]:
# # 2) Single-batch DataLoader
# check_loader = torch.utils.data.DataLoader(train_loader, batch_size=2, shuffle=False)
# batch = next(iter(check_loader))

# ex_xtok = batch["x_tok"].to(DEVICE)
# ex_xmask = batch["x_masks"].to(DEVICE)

# # Forward pass for a single batch
# output = untrained_model(input_ids=ex_xtok, attention_mask=ex_xmask)

# # Print the output shape
# print("Output shape (S_cube):", output.shape)


In [None]:
# Imports
import os  # For file and directory operations
import time  # For generating the timestamp in filenames
import torch  # Core PyTorch library
import torch.nn as nn  # For defining loss functions
import math  # For calculating exponential in perplexity calculation
from tqdm import tqdm  # For progress bars in training and validation loops
import sys
from sklearn.metrics import precision_score, recall_score, f1_score


# ------------------------------------------------------------------------
# Define a new function to train the BertTiny model using attention masks
def _bert_append_log(log_file, text, enabled):
    """Append ``text`` to ``log_file`` when file logging is enabled."""
    if enabled:
        with open(log_file, 'a') as log:
            log.write(text)


def _bert_loggable_config(config):
    """Shallow copy of ``config`` minus the device, data-loader and tokenizer entries.

    Those entries are not useful in a JSON log; any other value json cannot
    encode is stringified by the caller via ``json.dumps(default=str)``.
    Missing keys are simply ignored.
    """
    skipped = {"DEVICE", "train_loader", "val_loader", "test_loader", "tokenizer"}
    return {key: value for key, value in config.items() if key not in skipped}


def _bert_pos_weight(loader, label_idx):
    """Ratio of label==0 cells to label==1 cells in ``y_tok[..., label_idx]`` (one pass).

    Returns 1.0 (no reweighting) when there are no positive cells, so the BCE
    ``pos_weight`` stays finite.
    """
    num_pos, num_neg = 0, 0
    for batch in loader:
        labels = batch['y_tok'][..., label_idx]
        num_pos += int((labels == 1).sum())
        num_neg += int((labels == 0).sum())
    if num_pos == 0:
        print("Warning: no positive label cells in training data; using pos_weight=1.0")
        return 1.0
    return num_neg / num_pos


def _bert_perplexity(avg_loss):
    """``exp(avg_loss)``, or ``inf`` if the exponential overflows a float."""
    try:
        return math.exp(avg_loss)
    except OverflowError:
        return float('inf')


def train_bert(model, train_data, val_data, DEVICE, batch_size=8, lr=1.4e-5, mu=0.25, max_epochs=4, patience=3, save_int=2, save_dir='../models/', save_name='bert_', config=None, label_idx=6):
    """Train a per-cell grid classifier (e.g. ``BertTiny``) with weighted BCE and early stopping.

    Each dataset item must be a dict with ``'x_tok'`` and ``'x_masks'`` (passed
    positionally to ``model``) and ``'y_tok'``, whose last-axis channel
    ``label_idx`` holds the 0/1 cell labels. ``model`` must return one logit
    per label cell.

    Args:
        model: Module called as ``model(input_ids, attention_mask)``.
        train_data: Training dataset; wrapped in a shuffled ``DataLoader``.
        val_data: Validation dataset; wrapped in an unshuffled ``DataLoader``.
        DEVICE: Device that batches and the loss ``pos_weight`` are moved to.
        batch_size: DataLoader batch size.
        lr: Adagrad learning rate.
        mu: Max gradient norm passed to ``clip_grad_norm_``.
        max_epochs: Upper bound on the number of epochs.
        patience: Stop after this many consecutive epochs without a lower
            validation loss.
        save_int: Save the state dict every ``save_int`` epochs (and once at
            the end) and write a text log; ``0`` disables all file output.
        save_dir: Existing directory for checkpoints and logs.
        save_name: Prefix for ``<save_name>_<timestamp>.pth`` / ``.txt``.
        config: Optional config dict written at the top of the log (device,
            loaders and tokenizer omitted; other non-JSON values stringified).
        label_idx: Channel of ``y_tok``'s last axis that holds the labels.

    Returns:
        The trained ``model``. This is its state after the last epoch run,
        which may be later than the reported best epoch.

    Raises:
        ValueError: If ``save_int > 0`` and ``save_dir`` does not exist.
    """
    import json

    # Global print-option change kept from the original implementation.
    torch.set_printoptions(profile="full")

    log_enabled = save_int > 0
    if log_enabled and not os.path.exists(save_dir):
        raise ValueError(f"Directory '{save_dir}' DNE")

    # Timestamped names for the checkpoint and log of this run
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    model_path = os.path.join(save_dir, f"{save_name}_{timestamp}.pth")
    log_file = os.path.join(save_dir, f"{save_name}_{timestamp}.txt")

    if config is not None and log_enabled:
        with open(log_file, 'w') as log:
            log.write("\nFinal configuration:\n")
            log.write(json.dumps(_bert_loggable_config(config), indent=2, default=str))
            log.write("\n\n" + "=" * 80 + "\n\n")

    opt = torch.optim.Adagrad(model.parameters(), lr=lr)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False)

    # Weight positive cells by the negative/positive ratio of the training set.
    loss_fn = nn.BCEWithLogitsLoss(
        pos_weight=torch.tensor([_bert_pos_weight(train_loader, label_idx)], dtype=torch.float).to(DEVICE)
    )

    def _batch_loss(batch):
        # Flatten (batch, rows, cols) logits and labels to 1-D for BCE.
        logits = model(batch['x_tok'].to(DEVICE), batch['x_masks'].to(DEVICE)).reshape(-1)
        labels = batch['y_tok'][..., label_idx].to(DEVICE).reshape(-1).float()
        return loss_fn(logits, labels)

    best_epoch = 0
    best_avgtrloss = best_perp = float('inf')
    best_avgvalloss = best_valperp = float('inf')
    nimp_ctr = 0
    epoch = 0
    training = True

    while training and epoch < max_epochs:
        print(f'Epoch {epoch}')
        _bert_append_log(log_file, f"\nEpoch {epoch}\n", log_enabled)

        # ---- Train step ----
        model.train()
        curr_trloss = 0.0
        for batch in tqdm(train_loader, desc='Batch Processing'):
            opt.zero_grad()
            loss = _batch_loss(batch)
            curr_trloss += loss.item()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=mu)
            opt.step()

        # ---- Validation step ----
        model.eval()
        curr_valloss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc='Validation Processing'):
                curr_valloss += _batch_loss(batch).item()

        # BCEWithLogitsLoss (default reduction='mean') already averages over all
        # cells of a batch, so the mean of batch losses is a per-cell loss and
        # perplexity is exp(per-cell loss). The previous code divided again by
        # batch_size * 2500, which pinned the reported value near 1.
        curr_avgtrloss = curr_trloss / max(len(train_loader), 1)
        curr_avgvalloss = curr_valloss / max(len(val_loader), 1)
        curr_perp = _bert_perplexity(curr_avgtrloss)
        curr_valperp = _bert_perplexity(curr_avgvalloss)

        train_msg = f'Train Loss: {curr_avgtrloss}, Perplexity: {curr_perp}'
        val_msg = f'Val Loss: {curr_avgvalloss}, Perplexity: {curr_valperp}'
        print(train_msg)
        print(val_msg + '\n')
        _bert_append_log(log_file, f'{train_msg}\n{val_msg}\n', log_enabled)

        # Early stopping on validation loss (same ordering as its perplexity,
        # but unaffected by exp overflow).
        if curr_avgvalloss < best_avgvalloss:
            best_perp = curr_perp
            best_valperp = curr_valperp
            best_avgtrloss = curr_avgtrloss
            best_avgvalloss = curr_avgvalloss
            best_epoch = epoch
            nimp_ctr = 0
        else:
            nimp_ctr += 1

        if nimp_ctr >= patience:
            stop_msg = (
                f'\nEARLY STOPPING at epoch {epoch}, best epoch {best_epoch}\n'
                f'Train Loss = {best_avgtrloss}, Perplexity = {best_perp}\n'
                f'Val Loss = {best_avgvalloss}, Perplexity = {best_valperp}'
            )
            print(stop_msg)
            _bert_append_log(log_file, stop_msg + '\n', log_enabled)
            training = False

        # Periodic checkpoint (overwrites the same file each time)
        if log_enabled and (epoch + 1) % save_int == 0:
            torch.save(model.state_dict(), model_path)
            print("Model Saved")
            _bert_append_log(log_file, "Model Saved\n", log_enabled)

        epoch += 1
        print()

    # Final save of the last-epoch weights
    if log_enabled:
        torch.save(model.state_dict(), model_path)

    done_msg = (
        f'\nTRAINING DONE at epoch {epoch-1}, best epoch {best_epoch}\n'
        f'Train Loss = {best_avgtrloss}, Perplexity = {best_perp}\n'
        f'Val Loss = {best_avgvalloss}, Perplexity = {best_valperp}'
    )
    print(done_msg)
    _bert_append_log(log_file, done_msg + '\n', log_enabled)

    return model


In [None]:
# Train the model using the hyperparameters unpacked from the config above
train_kwargs = dict(
    batch_size=batch_size,  # batch size
    lr=lr,                  # learning rate
    mu=mu,                  # gradient-clipping max norm
    max_epochs=epochs,      # maximum number of epochs
    patience=patience,      # early-stopping patience
    save_int=save_int,      # checkpoint interval (0 disables saving/logging)
    save_dir=save_dir,      # checkpoint/log directory
    save_name=save_name,    # base name for checkpoints/logs
    config=config,          # full config, written to the log
)
trained_model = train_bert(untrained_model, train_loader, val_loader, DEVICE, **train_kwargs)


# # Call the train_bert function with the loaded model and config hyperparameters
# trained_model = train_bert(
#     untrained_model,                  # Pass the BertTiny model
#     train_loader,   # Training dataset
#     val_loader,     # Validation dataset
#     DEVICE,                 # Device for computation (CPU/GPU)
#     batch_size=2,  # Batch size from config
#     lr=1e-5,                  # Learning rate from config
#     mu=mu,                  # Gradient clipping max norm from config
#     max_epochs=4,      # Maximum number of epochs from config
#     patience=2,      # Early stopping patience
#     save_int=0,      # Interval at which to save model
#     save_dir=save_dir,      # Directory path to save checkpoints
#     save_name=save_name,    # Base name used for saving checkpoints/logs
#     config=config           # Full config for logging
# )


In [None]:
from utils import inferutil
importlib.reload(inferutil)
from utils.inferutil import binfer_one

# Inference parameters, reused by the validation check in the next cell
loc, thresh, cond, disp_max = 0, 0.6, '>', True

# Run single-item inference on position `loc` of the training data
binfer_one(
    trained_model, train_loader,
    loc=loc, threshold=thresh, condition=cond, disp_max=disp_max, device=DEVICE,
)

In [None]:
# Same single-item inference, this time on the validation data
binfer_one(
    trained_model, val_loader,
    loc=loc, threshold=thresh, condition=cond, disp_max=disp_max, device=DEVICE,
)