<a href="https://colab.research.google.com/github/envomp/predicate_logic_training/blob/main/predicate_logic_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [16]:
!pip install einx dataclasses_json llama_models

Collecting llama_models
  Downloading llama_models-0.0.55-py3-none-any.whl.metadata (8.2 kB)
Collecting tiktoken (from llama_models)
  Downloading tiktoken-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.6 kB)
Downloading llama_models-0.0.55-py3-none-any.whl (1.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m18.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading tiktoken-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m47.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tiktoken, llama_models
Successfully installed llama_models-0.0.55 tiktoken-0.8.0


In [None]:
# Mount Google Drive at /content/drive (requires the `google.colab` package, i.e. a Colab runtime).
# NOTE(review): the dataset paths used further down are relative ("../dataset/..."), so
# confirm that the working directory is where those paths expect it to be.
from google.colab import drive
drive.mount('/content/drive')

In [18]:
from hierarchical_routing import *
import processor
from eval import eval_model, visualize_routes
from data_preprocessing import pad_collate

In [None]:
# Objective switch read by loss_fn, inference_model and call_model:
# True selects the binary-classification path, False the token cross-entropy path.
for_classification = False

# Training split.
train_ds = processor.load(
    "../dataset/predicate_logic/train/70k/prop_examples_lp.txt",
    for_classification=for_classification,
)

# Validation split handed to train().
validation_ds = processor.load(
    "../dataset/predicate_logic/validation/prop_examples_lp.txt",
    for_classification=for_classification,
)

# Examples consumed by inference_model().
# NOTE(review): this reads the same file as validation_ds a second time — confirm whether a
# separate copy is required or the already loaded object could be reused.
inference_ids = processor.load(
    "../dataset/predicate_logic/validation/prop_examples_lp.txt",
    for_classification=for_classification,
)

def ds_loader(ds, ds_length, epoch):
    """Build the shuffled ``DataLoader`` used for one pass over ``ds``.

    Args:
        ds: Dataset to batch; handed to ``DataLoader`` unchanged.
        ds_length: Currently unused (only the disabled curriculum below reads it).
        epoch: Currently unused (only the disabled curriculum below reads it).

    Returns:
        A ``DataLoader`` over ``ds`` with ``batch_size=12`` and ``shuffle=True`` whose
        batches are assembled by ``pad_collate`` with ``processor.pad`` as padding.
    """
    # Disabled curriculum experiment, kept for reference:
    # curriculum = processor.train_curriculum(ds, epoch, select_layer_items=ds_length // 7, non_select_layer_items=ds_length // 140)

    def collate_batch(samples):
        # processor.pad is resolved each time a batch is collated, not when the loader is built.
        return pad_collate(samples, padding=processor.pad)

    return DataLoader(ds, batch_size=12, shuffle=True, collate_fn=collate_batch)

def loss_fn(llm_output, depth):
    """Task loss plus a supervised "fixed routing" loss on the gate probabilities.

    Args:
        llm_output: 4-tuple ``(logits, labels, gating, history)`` produced by the model.
            ``labels`` are taken from this tuple (not from the batch); ``history`` is unused.
        depth: Indexable by the position of an entry in ``gating``; ``depth[pos]`` is passed
            to ``is_forward_at_position_for_depth`` for the ``pos``-th entry.

    Returns:
        The task loss (binary cross-entropy with logits when ``for_classification`` is set,
        otherwise cross-entropy) plus the routing loss: for every entry of ``gating`` the
        mean BCE between each gate probability and its 0/1 target, averaged over the entries.
        When ``gating`` is empty only the task loss is returned; an entry without any gate
        contributes nothing instead of raising ``ZeroDivisionError``.
    """
    logits, labels, gating, history = llm_output

    if for_classification:
        # Only the last index along dim 1 is scored.
        # NOTE(review): assumes labels are float and shaped like logits[:, -1] — TODO confirm.
        logit = logits[:, -1]
        total_loss = F.binary_cross_entropy_with_logits(logit, labels)
    else:
        # cross_entropy wants the class dimension at index 1, hence the transpose
        # (assumes logits are (batch, seq, vocab) — TODO confirm).
        total_loss = F.cross_entropy(logits.transpose(1, 2), labels, label_smoothing=0.)

    # Disabled experiments, kept for reference:
    # layers_traversed_loss = 0.0
    # for next_layer_prob in gating:
    #     layers_traversed_loss += F.binary_cross_entropy(next_layer_prob, torch.ones_like(next_layer_prob))

    # # whatever you think, be confident in it
    # confidence_loss = 0.0
    # for next_layer_prob in gating:
    #     confidence = 0.5 - torch.abs(next_layer_prob - 0.5)
    #     confidence_loss += confidence.mean()
    # confidence_loss *= 0.01
    # confidence_loss /= len(gating)
    # total_loss += confidence_loss

    # fixed routing
    if len(gating) == 0:
        # No gates at all: nothing to supervise, and dividing by len(gating) would fail.
        return total_loss

    total_routing_loss = 0
    for pos, gates in enumerate(gating):
        if len(gates) == 0:
            # Entry without routing decisions: skip it rather than divide by zero.
            continue
        routing_loss = 0
        for i, next_layer_prob in enumerate(gates):
            is_forward = is_forward_at_position_for_depth(i, depth[pos])
            # Target is 1 where the fixed route goes forward at step i, 0 otherwise.
            if is_forward > 0:
                target_prob = torch.ones_like(next_layer_prob)
            else:
                target_prob = torch.zeros_like(next_layer_prob)
            routing_loss += F.binary_cross_entropy(next_layer_prob, target_prob)
        total_routing_loss += routing_loss / len(gates)
    # In-place add on purpose: keeps total_loss in the dtype of the task loss.
    total_loss += total_routing_loss / len(gating)

    return total_loss

def inference_model(llm_model, epoch=0):
    """Evaluate ``llm_model`` on ``inference_ids`` and visualise the recorded routes.

    Args:
        llm_model: Callable model, invoked as
            ``llm_model(xs, labels=..., depths=..., global_attention_mask=...)`` and expected
            to return a 4-tuple ``(logits, labels, gates, history)``.
        epoch: Unused at the moment; only the disabled depth-limited evaluation reads it.

    Side effects:
        Calls ``eval_model`` and then ``visualize_routes`` with one route per model
        invocation, all under ``torch.no_grad()``.
    """
    if for_classification:
        vocabulary = {1: 1, 0: 0}
        answer_position = 1
    else:
        vocabulary = processor.special_tokens
        answer_position = 0

    routes = []

    def model_invocation(xs, labels, depth, global_attention_mask, **kwargs):
        # Each value is wrapped in a one-element list before becoming a tensor
        # (presumably a batch of one — TODO confirm against eval_model).
        # NOTE(review): the device is hard-coded to "cuda"; xs is passed through untouched.
        label_batch = torch.tensor([labels]).to("cuda")
        depth_batch = torch.tensor([depth]).to("cuda")
        mask_batch = torch.tensor([global_attention_mask]).to("cuda")
        logits, _, _, history = llm_model(
            xs,
            labels=label_batch,
            depths=depth_batch,
            global_attention_mask=mask_batch,
        )
        # Keep the route of this call as plain Python numbers.
        routes.append([step.item() for step in history])
        return logits

    with torch.no_grad():
        # Disabled variant that only evaluates examples up to the current epoch's depth:
        # eval_model(model_invocation, [x for x in inference_ids if x["depth"] <= epoch], vocabulary=vocabulary, answer_position=answer_position)
        eval_model(model_invocation, inference_ids, vocabulary=vocabulary, answer_position=answer_position)
        visualize_routes(routes)

def call_model(llm, batch, loss_fn):
    """Run one training forward pass on ``batch`` and return its loss.

    Args:
        llm: Model exposing
            ``forward_train(input_ids, labels, depths=..., global_attention_mask=...)``.
        batch: Mapping with ``"input_ids"``, ``"labels"``, ``"depth"`` and
            ``"global_attention_mask"`` entries, each supporting ``.to("cuda")``.
        loss_fn: Callable invoked as ``loss_fn(llm_output, depths)``.

    Returns:
        Whatever ``loss_fn`` returns for the model output and the batch depths.
    """
    # NOTE(review): the device is hard-coded to "cuda".
    on_device = {
        key: batch[key].to("cuda")
        for key in ("input_ids", "labels", "depth", "global_attention_mask")
    }

    # In classification mode the model is called without labels.
    if for_classification:
        model_labels = None
    else:
        model_labels = on_device["labels"]

    llm_output = llm.forward_train(
        on_device["input_ids"],
        model_labels,
        depths=on_device["depth"],
        global_attention_mask=on_device["global_attention_mask"],
    )
    return loss_fn(llm_output, on_device["depth"])

# Build the model, wire up optimisation, then train and plot the result.
# NOTE(review): `root_path` is not defined in any cell visible here — presumably it comes from
# `from hierarchical_routing import *`; confirm it is available after Restart & Run All.
# NOTE(review): no RNG seed is set in the visible cells, so shuffling differs between runs.
llm_model = create_model(root_path, ModelArgs, Transformer)

# Previous optimizer choice, kept for reference:
# optimizer = AdamWScheduleFree(llm_model.parameters(), weight_decay=0.05, betas=(0.9, 0.98), lr=1e-4, warmup_steps=2000)
# NOTE(review): torch.optim.Adam applies weight_decay as an L2 term added to the gradient, not
# as decoupled decay; if AdamW-style decay is intended, torch.optim.AdamW is the counterpart.
optimizer = torch.optim.Adam(
    llm_model.parameters(),
    lr=1e-5,
    betas=(0.9, 0.98),
    weight_decay=0.1,
)

# The learning rate is multiplied by 0.9 on every scheduler.step() call.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

train_conf = TrainConf(
    epochs=10,
    optimizer=optimizer,
    scheduler=scheduler,
    loss_fn=loss_fn,
    ds_loader=ds_loader,
    eval_model=inference_model,
)

visualize(
    train(train_conf, llm_model, train_ds, validation_ds, model_call=call_model)
)