<a href="https://colab.research.google.com/github/envomp/predicate_logic_training/blob/main/predicate_logic_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Versions are pinned to the ones resolved in the recorded run of this cell, so that a
# fresh runtime installs the same APIs (llama_models in particular is imported by module path below).
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install einx==0.3.0 dataclasses_json==0.6.7 llama_models==0.0.55 blobfile==3.0.0

Collecting einx
  Downloading einx-0.3.0-py3-none-any.whl.metadata (6.9 kB)
Collecting dataclasses_json
  Downloading dataclasses_json-0.6.7-py3-none-any.whl.metadata (25 kB)
Collecting llama_models
  Downloading llama_models-0.0.55-py3-none-any.whl.metadata (8.2 kB)
Collecting blobfile
  Downloading blobfile-3.0.0-py3-none-any.whl.metadata (15 kB)
Collecting marshmallow<4.0.0,>=3.18.0 (from dataclasses_json)
  Downloading marshmallow-3.23.1-py3-none-any.whl.metadata (7.5 kB)
Collecting typing-inspect<1,>=0.4.0 (from dataclasses_json)
  Downloading typing_inspect-0.9.0-py3-none-any.whl.metadata (1.5 kB)
Collecting tiktoken (from llama_models)
  Downloading tiktoken-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.6 kB)
Collecting pycryptodomex>=3.8 (from blobfile)
  Downloading pycryptodomex-3.21.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.4 kB)
Collecting mypy-extensions>=0.3.0 (from typing-inspect<1,>=0.4.0->dataclasses_json)
  D

In [2]:
# Mount Google Drive into the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
from dataclasses import dataclass
from dataclasses_json import dataclass_json
from typing import Optional


@dataclass_json
@dataclass
class ModelArgs:
    """Hyper-parameters consumed by the transformer modules defined in this notebook.

    The hand-written ``__init__`` takes arbitrary keyword arguments: a key that
    names an existing attribute overrides its default, any other key is ignored.
    After the overrides are applied, ``n_kv_heads`` falls back to ``n_heads`` when
    unset and the head configuration is sanity-checked with assertions.
    """

    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1
    output_size: Optional[int] = None
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    meta_embeddings: int = 0

    max_batch_size: int = 32
    max_seq_len: int = 2048

    def __init__(self, **kwargs):
        # Lazily filter the overrides so each key is checked right before it is applied.
        recognised = ((name, value) for name, value in kwargs.items() if hasattr(self, name))
        for name, value in recognised:
            setattr(self, name, value)

        # Without an explicit KV head count every query head gets its own KV head.
        if self.n_kv_heads is None:
            self.n_kv_heads = self.n_heads

        assert self.n_kv_heads <= self.n_heads
        assert self.n_heads % self.n_kv_heads == 0
        assert self.dim % self.n_heads == 0

In [4]:
from torch.utils.data import default_collate
import torch


def next_word_prediction_labels(input_ids, answer_ids):
    """Build shifted (labels, inputs) lists for next-token prediction.

    The prompt positions are labelled ``-100`` and the answer positions carry the
    answer ids; the label list then drops its first entry while the model input
    (prompt followed by answer) drops its last entry, so both lists have the same
    length and ``labels[i]`` is the token that follows ``inputs[i]``.

    Returns:
        A ``(labels, inputs)`` tuple of new lists; the arguments are not modified.
    """
    prompt_mask = [-100] * len(input_ids)
    labels = prompt_mask + list(answer_ids)
    labels.pop(0)

    shifted_inputs = input_ids + answer_ids
    shifted_inputs.pop()
    return labels, shifted_inputs


def pad_collate(batch, padding, length=None):
    """Pad every sample of ``batch`` to a common length and collate it into tensors.

    Args:
        batch: list of sample dicts with the keys ``input_ids``, ``labels``,
            ``segments``, ``global_attention_mask``, ``depth`` and
            ``for_classification``.
        padding: value appended to ``input_ids`` to reach the target length.
        length: target length; when falsy, the longest ``input_ids`` in the batch is used.

    Returns:
        The result of ``default_collate`` over the padded samples.
    """
    target_len = length if length else max([len(x["input_ids"]) for x in batch])

    def _padded(values, fill):
        # Right-pad ``values`` with ``fill`` up to the target length.
        return values + [fill] * (target_len - len(values))

    rows = []
    for sample in batch:
        row = {"input_ids": torch.tensor(_padded(sample["input_ids"], padding))}

        if sample["for_classification"]:
            # Classification targets are used as-is (no padding), as floats.
            row["labels"] = torch.tensor(sample["labels"]).float()
        else:
            row["labels"] = torch.tensor(_padded(sample["labels"], -100))

        segments = sample["segments"]
        if segments:
            # Padding positions get a value one above the sample's last segment value.
            row["segments"] = torch.tensor(_padded(segments, segments[-1] + 1))

        mask = sample["global_attention_mask"]
        if mask:
            row["global_attention_mask"] = torch.tensor(_padded(mask, False))

        row["depth"] = sample["depth"]
        rows.append(row)

    return default_collate(rows)

In [5]:
import json
import random

from llama_models.llama3.api import Tokenizer

# Word list indexed by atom id: process_atoms renders an atom as adjectives[int(atom)].
adjectives = ["wrong", "glamorous", "stormy", "weary", "witty", "tense", "pessimistic", "frightened", "cruel",
              "helpful", "hurt", "comfortable", "worried", "aggressive", "blushing", "proud", "rude", "lucky",
              "fearless", "diplomatic", "hypocritical", "quaint", "perfect", "jittery", "friendly", "horrible",
              "attentive", "victorious", "naughty", "condemned", "fancy", "zealous", "crowded", "sincere", "busy",
              "curious", "talented", "modern", "bright", "mean", "foolish", "tender", "mysterious", "alert", "fine",
              "amused", "bad-tempered", "lonely", "good", "muddy", "frantic", "shy", "versatile", "loving", "elated",
              "powerful", "troubled", "easy", "innocent", "hilarious", "strange", "disobedient", "elegant", "joyous",
              "careless", "shiny", "adorable", "inquisitive", "precious", "wide-eyed", "beautiful", "confused",
              "embarrassed", "famous", "nervous", "plain", "distinct", "courageous", "gleaming", "bored",
              "broad-minded", "fragile", "outstanding", "cooperative", "inexpensive", "charming", "confident", "grumpy",
              "messy", "excited", "long", "talkative", "reserved", "tame", "wandering", "different", "agreeable",
              "cute", "scared", "popular", "exuberant", "impartial", "old-fashioned", "clean", "helpless",
              "thoughtless", "selfish", "serious", "homely", "worrisome", "polite", "tired", "ugly", "light", "ugliest",
              "tidy", "sleepy", "unpleasant", "enchanting", "intellectual", "combative", "rational", "uptight",
              "sensible", "supportive", "spotless", "disgusted", "ambitious", "pleasant", "silly", "clumsy", "average",
              "vivacious", "gifted", "straightforward", "smart", "impatient", "outrageous", "calm", "gorgeous", "frail",
              "dull", "thoughtful", "dishonest", "difficult", "bossy", "stubborn", "anxious", "stupid", "attractive"]
# Reserved token ids for the structural markers of the custom (non-Llama) tokenization
# produced by tokenize_input / tokenize_output.
rule_block = 200  # opens the rules section
deduction_separator = 201  # placed between a rule's premises and its conclusion
rule_separator = 202  # only referenced from segment_by in the visible code
fact_block = 203  # opens the facts section
query_block = 204  # opens the query section
preds_block = 205  # opens the predicates section
end_of_turn = 206  # appended at the end of the tokenized input
end_of_text = 207  # appended at the end of the tokenized answer
# Maps a 0/1 label to its answer token id (used by tokenize_output and as the default vocabulary of eval_model).
special_tokens = {1: 210, 0: 211}
pad = 208  # presumably the padding id handed to pad_collate — TODO confirm against the caller


def process_atoms(facts):
    """Translate an iterable of atom ids into their adjectives.

    Each atom is converted with ``int`` and used as an index into ``adjectives``.
    """
    words = []
    for atom in facts:
        words.append(adjectives[int(atom)])
    return words


def process_rules(rules):
    """Render ``(premises, conclusion)`` rules as one sentence string.

    Every rule becomes ``"<p1> and <p2> ... is <conclusion>"`` and the rules are
    joined with ``". "``.
    """
    sentences = [
        f"{' and '.join(process_atoms(premises))} is {process_atoms([conclusion])[0]}"
        for premises, conclusion in rules
    ]
    return ". ".join(sentences)


def llama_tokenize_input(blob, tokenizer: Tokenizer):
    """Render a sample as a natural-language prompt and encode it with the Llama tokenizer.

    Args:
        blob: sample dict with ``facts``, ``rules`` and ``query``; ``query`` is the list
            of query atoms produced by ``process_labels`` and only its first entry is used.
        tokenizer: Llama ``Tokenizer`` used to build the ``ChatFormat`` encoder.

    Returns:
        The token ids of the encoded prompt.
    """
    from llama_models.llama3.api import ChatFormat

    facts_text = ", ".join(process_atoms(blob["facts"]))
    rules_text = process_rules(blob["rules"])

    # process_atoms iterates over its argument, so the query atom has to be wrapped in a
    # list: passing the bare atom would iterate over the characters of a string id (turning
    # "37" into atom 3) or fail on an int. An atom that already is a list/tuple keeps the
    # previous behaviour of using its first element.
    query_atom = blob["query"][0]
    if isinstance(query_atom, (list, tuple)):
        query_atom = query_atom[0]
    query_text = process_atoms([query_atom])[0]

    message = f"""
facts: {facts_text}
rules: {rules_text}
query: {query_text}
result: """
    chat_format = ChatFormat(tokenizer)
    return chat_format.encode_content(message).tokens


def llama_tokenize_output(blob, tokenizer: Tokenizer, for_classification: bool):
    """Encode the first label of a sample as a training target.

    For classification the target is a two-element one-hot list (``[0, 1]`` for a
    truthy label, ``[1, 0]`` otherwise). Otherwise the text ``"True"``/``"False"``
    is encoded with the Llama chat format and followed by the ``<|eom_id|>`` token.
    """
    from llama_models.llama3.api import ChatFormat

    first_label = blob["label"][0]
    if for_classification:
        return [0, 1] if first_label else [1, 0]

    answer = "True" if first_label else "False"
    chat_format = ChatFormat(tokenizer)
    encoded_answer = chat_format._encode_content(answer)[0]
    return encoded_answer + [tokenizer.special_tokens["<|eom_id|>"]]


def tokenize_input(blob, global_attention_mask=False):
    """Encode a sample with the custom integer vocabulary.

    The sequence layout is: predicates block, rules block (each rule as its
    premises, ``deduction_separator``, conclusion), facts block, query block and
    a final ``end_of_turn`` token.

    Returns:
        The list of token ids, or ``(token_ids, mask)`` when
        ``global_attention_mask`` is truthy, where ``mask`` holds one ``False``
        per token.
    """
    query = [query_block]
    query.extend(int(atom) for atom in blob["query"])
    facts = [fact_block]
    facts.extend(int(atom) for atom in blob["facts"])
    preds = [preds_block]
    preds.extend(int(atom) for atom in blob["preds"])

    rules = [rule_block]
    for premises, conclusion in blob["rules"]:
        rules.extend(int(atom) for atom in premises)
        rules += [deduction_separator, int(conclusion)]

    input_ids = preds + rules + facts + query + [end_of_turn]

    if global_attention_mask:
        return input_ids, [False] * len(input_ids)
    return input_ids


def tokenize_output(blob, for_classification: bool):
    """Encode the labels of a sample as a training target.

    For classification only the first label is used and returned as a one-hot
    pair (``[0, 1]`` when truthy, ``[1, 0]`` otherwise). Otherwise every label is
    mapped through ``special_tokens`` and ``end_of_text`` is appended.
    """
    if for_classification:
        if blob["label"][0]:
            return [0, 1]
        return [1, 0]

    answer_ids = [special_tokens[int(flag)] for flag in blob["label"]]
    answer_ids.append(end_of_text)
    return answer_ids


# Token ids that generate_segments marks with 1; every other id is marked 0.
# NOTE(review): preds_block is not part of this list — confirm that is intended.
segment_by = [deduction_separator, rule_separator, rule_block, fact_block, query_block, end_of_turn, end_of_text, pad]


def generate_segments(input_ids):
    """Return a 0/1 list flagging which tokens of ``input_ids`` appear in ``segment_by``."""
    return [1 if token in segment_by else 0 for token in input_ids]


def is_reachable(facts, rules, query):
    """Forward-chain ``rules`` from ``facts`` and report whether ``query`` is derived.

    Args:
        facts: iterable of initially known atoms.
        rules: iterable of ``(premises, conclusion)`` pairs; a rule fires once all
            of its premises are known.
        query: the atom to test.

    Returns:
        ``1`` if ``query`` ends up in the set of known atoms, ``0`` otherwise.
    """
    known = set(facts)
    while True:
        grew = False
        for premises, conclusion in rules:
            if all(premise in known for premise in premises) and conclusion not in known:
                known.add(conclusion)
                grew = True
        # Stop once a full pass over the rules derives nothing new.
        if not grew:
            break
    return int(query in known)


def process_labels(blob, expand):
    """Normalise ``query`` and ``label`` of a sample into lists (in place).

    With ``expand`` every predicate in ``blob["preds"]`` becomes a query and its
    label is recomputed with ``is_reachable``; otherwise the existing query and
    label are each wrapped in a one-element list.

    Returns:
        The same ``blob`` object, modified.
    """
    if expand:
        candidates = blob["preds"]
        blob["query"] = candidates
        blob["label"] = [is_reachable(blob["facts"], blob["rules"], candidate) for candidate in candidates]
    else:
        blob["query"] = [blob["query"]]
        blob["label"] = [blob["label"]]
    return blob


def load(file, expand=False, tokenizer=None, for_classification=False):
    """Read a JSON dataset and turn every entry into a tokenized sample dict.

    Args:
        file: path of a JSON file holding a list of sample blobs.
        expand: forwarded to ``process_labels``.
        tokenizer: when not ``None`` the Llama text tokenization is used, otherwise
            the custom integer tokenization.
        for_classification: produce classification targets instead of
            next-token-prediction targets.

    Returns:
        A list of dicts with the keys ``input``, ``output``, ``input_ids``,
        ``labels``, ``global_attention_mask``, ``segments``, ``depth`` and
        ``for_classification``.
    """
    with open(file, "r") as f:
        raw = json.load(f)

    # All labels are normalised before any sample is tokenized.
    processed = [process_labels(item, expand) for item in raw]

    ds = []
    for elem in processed:
        if tokenizer is not None:
            mask = None
            inp = llama_tokenize_input(elem, tokenizer)
            out = llama_tokenize_output(elem, tokenizer, for_classification)
        else:
            inp, mask = tokenize_input(elem, global_attention_mask=True)
            out = tokenize_output(elem, for_classification)

        # input/output for inference, input_ids/labels for training
        if for_classification:
            input_ids, labels = inp, out
        else:
            labels, input_ids = next_word_prediction_labels(inp, out)

        sample = {
            "input": inp,
            "output": out,
            "input_ids": input_ids,
            "labels": labels,
            "global_attention_mask": mask if mask else None,
            # Segments are only generated for the custom tokenization.
            "segments": generate_segments(input_ids) if not tokenizer else None,
            "depth": elem["depth"],
            "for_classification": for_classification,
        }
        ds.append(sample)
    return ds


def train_curriculum(ds: list, epoch, select_layer_items=500, non_select_layer_items=50):
    """Draw the training samples for one curriculum epoch, consuming them from ``ds``.

    Samples are bucketed by ``depth`` (0..6). For ``epoch < 7`` the first
    ``select_layer_items`` samples of depth ``epoch`` are taken, followed by the
    first ``non_select_layer_items`` samples of every shallower depth; the
    remaining samples are written back to ``ds`` grouped by depth. For
    ``epoch >= 7`` all samples are returned shuffled and ``ds`` is left empty.

    Raises:
        IndexError: if a bucket holds fewer samples than requested.
    """
    buckets = {depth: [] for depth in range(7)}
    for sample in ds:
        buckets[sample["depth"]].append(sample)
    ds.clear()

    if epoch >= 7:
        everything = [sample for depth in range(7) for sample in buckets[depth]]
        random.shuffle(everything)
        return everything

    # (depth, amount) pairs in the order the samples are drawn.
    quotas = [(epoch, select_layer_items)]
    quotas.extend((depth, non_select_layer_items) for depth in range(epoch))

    train_ds = []
    for depth, amount in quotas:
        for _ in range(amount):
            train_ds.append(buckets[depth].pop(0))

    for depth in range(7):
        ds.extend(buckets[depth])
    return train_ds


def select(ds: list, select_items=500):
    """Take the first ``select_items`` samples of every depth (0..6) out of ``ds``.

    The selected samples are returned ordered by depth; the samples that were
    not selected are written back to ``ds`` grouped by depth.

    Raises:
        IndexError: if any depth holds fewer than ``select_items`` samples.
    """
    buckets = {depth: [] for depth in range(7)}
    for sample in ds:
        buckets[sample["depth"]].append(sample)
    ds.clear()

    train_ds = []
    for bucket in buckets.values():
        for _ in range(select_items):
            train_ds.append(bucket.pop(0))

    for depth in range(7):
        ds.extend(buckets[depth])
    return train_ds


In [6]:
import torch
import matplotlib.pyplot as plt


def eval_model(llm_model, inference_ids, vocabulary=special_tokens, answer_position=0):
    """Run the model on every sample individually and print prediction statistics.

    Args:
        llm_model: callable invoked as ``llm_model(model_input, **blob)``; its result
            is indexed as ``logits[:, -1, :]``.
        inference_ids: iterable of sample dicts providing at least ``depth``,
            ``output`` and ``input``.
        vocabulary: mapping whose entries ``1`` and ``0`` are the token ids counted
            as the positive and the negative answer.
        answer_position: index into ``blob["output"]`` of the expected token.

    Returns:
        None. The counters are only printed.
    """
    correct = 0
    ood = 0
    false_positive = 0
    false_negative = 0
    positive = 0
    negative = 0
    correct_by_depth = {}
    incorrect_by_depth = {}
    for blob in inference_ids:
        depth = blob["depth"]
        expected = blob["output"][answer_position]
        # Batch of one sample, moved to the GPU (requires CUDA).
        model_input = torch.tensor([blob["input"]]).cuda()

        with torch.no_grad():
            # NOTE(review): every key of the sample dict is forwarded as a keyword argument, so
            # the model must accept all of them; the Transformer.forward defined in this notebook
            # only declares labels/depths/global_attention_mask — confirm which model this targets.
            logits = llm_model(model_input, **blob)
            # Greedy prediction taken from the last sequence position.
            last_token_logits = logits[:, -1, :]
            next_token_id = torch.argmax(last_token_logits, dim=-1)

        if next_token_id == vocabulary[1]:
            positive += 1
        if next_token_id == vocabulary[0]:
            negative += 1

        # A prediction outside the answer vocabulary is counted as out-of-distribution
        # and is not attributed to any depth bucket.
        if next_token_id not in vocabulary.values():
            ood += 1
        elif expected == next_token_id:
            correct += 1
            if depth not in correct_by_depth:
                correct_by_depth[depth] = 0
            correct_by_depth[depth] += 1
        else:
            if depth not in incorrect_by_depth:
                incorrect_by_depth[depth] = 0
            incorrect_by_depth[depth] += 1
            # A wrong prediction of the positive token is a false positive; any other
            # wrong in-vocabulary prediction is counted as a false negative.
            if next_token_id == vocabulary[1]:
                false_positive += 1
            else:
                false_negative += 1

    print(f"positive predictions: {positive}, negative predictions: {negative}")
    print(f"correct: {correct}, false positive: {false_positive}, false negative: {false_negative}, ood: {ood}")
    print(f"correct by depth: {correct_by_depth}")
    print(f"incorrect by depth: {incorrect_by_depth}")


def visualize_routes(routes):
    """Plot every route as a line of its values over the step index and show the figure."""
    for route in routes:
        steps = list(range(len(route)))
        plt.plot(steps, route)

    plt.xlabel('Time/Steps')
    plt.ylabel('Layer')
    plt.title('Layer Visitation over Time')
    plt.show()

In [7]:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

import math
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn, einsum

import einx
from einops import rearrange, repeat, reduce, pack, unpack


class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Divide by the RMS of the last dimension; eps keeps the rsqrt finite.
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        # Normalise in float32, cast back to the input dtype, then apply the gain.
        normed = self._norm(x.float()).type_as(x)
        return normed * self.weight


def apply_scaling(freqs: torch.Tensor):
    """Rescale rotary frequencies depending on their wavelength.

    Frequencies with a wavelength below ``old_context_len / high_freq_factor`` are
    kept, those with a wavelength above ``old_context_len / low_freq_factor`` are
    divided by ``scale_factor``, and the ones in between are blended between the
    two. Returns a new tensor with the dtype and device of ``freqs``.
    """
    # RoPE scaling (values obtained from grid search)
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192  # original llama3 length

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    def _rescale(freq):
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            return freq
        if wavelen > low_freq_wavelen:
            return freq / scale_factor
        assert low_freq_wavelen != high_freq_wavelen
        smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
        return (1 - smooth) * freq / scale_factor + smooth * freq

    return torch.tensor([_rescale(freq) for freq in freqs], dtype=freqs.dtype, device=freqs.device)


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
    """Precompute the complex rotary factors for positions ``0..end-1``.

    Returns a complex tensor with one row per position and ``dim // 2`` columns,
    each entry being ``exp(i * position * frequency)``. With ``use_scaled`` the
    frequencies are first passed through ``apply_scaling``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device, dtype=torch.float32)
    if use_scaled:
        inv_freq = apply_scaling(inv_freq)
    angles = torch.outer(positions, inv_freq)
    return torch.polar(torch.ones_like(angles), angles)  # complex64


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View ``freqs_cis`` so that it broadcasts against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[1], x.shape[-1])``; the result keeps
    those two sizes at dimension 1 and at the last dimension and uses size 1
    everywhere else.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [1] * ndim
    shape[1] = x.shape[1]
    shape[ndim - 1] = x.shape[-1]
    return freqs_cis.view(*shape)


def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys by the precomputed complex rotary factors.

    The last dimension of each input is read as consecutive (real, imag) pairs,
    multiplied by ``freqs_cis`` (broadcast via ``reshape_for_broadcast``) and
    flattened back. The results keep the dtype of the respective input.
    """

    def _as_complex(t):
        # Pair up the last dimension and reinterpret the pairs as complex numbers.
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = _as_complex(xq)
    k_complex = _as_complex(xk)
    rotation = reshape_for_broadcast(freqs_cis, q_complex)
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)


# https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py

def exists(val):
    """Return True unless ``val`` is ``None``."""
    return not (val is None)


def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise return ``d``, calling it first if it is callable."""
    if not exists(val):
        return d() if callable(d) else d
    return val


def divisible_by(num, den):
    """Return whether ``num`` is an exact multiple of ``den``."""
    remainder = num % den
    return remainder == 0


def pad_at_dim(t, pad: tuple[int, int], dim=-1, value=0.):
    """Pad tensor ``t`` along a single dimension.

    Args:
        t: tensor to pad.
        pad: ``(before, after)`` amounts for dimension ``dim``.
        dim: dimension to pad; negative values count from the end.
        value: fill value.

    Returns:
        ``t`` itself when ``pad == (0, 0)``, otherwise the padded tensor.
    """
    if pad == (0, 0):
        return t

    # F.pad lists its pairs starting from the last dimension, so every dimension
    # to the right of ``dim`` needs an explicit (0, 0) pair first.
    if dim < 0:
        trailing_dims = -dim - 1
    else:
        trailing_dims = t.ndim - dim - 1
    spec = (0, 0) * trailing_dims + tuple(pad)
    return F.pad(t, spec, value=value)


def l2norm(t, groups=1):
    """L2-normalise the last dimension of ``t`` separately within ``groups`` equal chunks."""
    grouped = rearrange(t, '... (g d) -> ... g d', g=groups)
    unit = F.normalize(grouped, p=2, dim=-1)
    return rearrange(unit, '... g d -> ... (g d)')


class always():
    """Callable that ignores its arguments and always returns the stored value."""

    def __init__(self, val):
        self.val = val

    def __call__(self, *_args, **_kwargs):
        return self.val


# positional embeddings

class AbsolutePositionalEmbedding(nn.Module):
    """Learned absolute positional embedding table of ``max_seq_len`` positions."""

    def __init__(self, dim, max_seq_len, l2norm_embed=False):
        super().__init__()
        # The dim ** -0.5 scaling is skipped when the embeddings are l2-normalised.
        self.scale = 1. if l2norm_embed else dim ** -0.5
        self.max_seq_len = max_seq_len
        self.l2norm_embed = l2norm_embed
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x, pos=None, seq_start_pos=None):
        """Return the (scaled) embeddings for ``pos``, defaulting to ``0..x.shape[1]-1``.

        When ``seq_start_pos`` is given it is subtracted from the positions and the
        result is clamped at zero.
        """
        seq_len, device = x.shape[1], x.device
        assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

        positions = pos if exists(pos) else torch.arange(seq_len, device=device)

        if exists(seq_start_pos):
            positions = (positions - seq_start_pos[..., None]).clamp(min=0)

        scaled = self.emb(positions) * self.scale
        if self.l2norm_embed:
            return l2norm(scaled)
        return scaled


class ScaledSinusoidalEmbedding(nn.Module):
    """Sinusoidal positional embedding multiplied by a learned scalar scale."""

    def __init__(self, dim, theta=10000):
        super().__init__()
        assert divisible_by(dim, 2)
        self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)

        half_dim = dim // 2
        exponents = torch.arange(half_dim).float() / half_dim
        # Non-persistent: the frequencies are recomputed rather than stored in checkpoints.
        self.register_buffer('inv_freq', theta ** -exponents, persistent=False)

    def forward(self, x, pos=None, seq_start_pos=None):
        """Return ``[sin, cos]`` features of the positions, times ``self.scale``.

        Positions default to ``0..x.shape[1]-1`` and are shifted by
        ``seq_start_pos`` when it is given.
        """
        seq_len, device = x.shape[1], x.device

        positions = pos if exists(pos) else torch.arange(seq_len, device=device)

        if exists(seq_start_pos):
            positions = positions - seq_start_pos[..., None]

        angles = einsum('i, j -> i j', positions, self.inv_freq)
        features = torch.cat((angles.sin(), angles.cos()), dim=-1)
        return features * self.scale


class RelativePositionBias(nn.Module):
    """Bucketed relative position bias with one learned value per bucket and head."""

    def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
        super().__init__()
        self.scale = scale
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
        """Map relative positions to bucket indices.

        Small distances get one bucket each; larger ones share logarithmically
        sized buckets, capped at the last bucket. In the non-causal case the
        buckets are split in two halves depending on the sign of the negated
        relative position.
        """
        distance = -relative_position
        if causal:
            offset = 0
            distance = torch.max(distance, torch.zeros_like(distance))
        else:
            num_buckets //= 2
            offset = (distance < 0).long() * num_buckets
            distance = torch.abs(distance)

        max_exact = num_buckets // 2
        is_small = distance < max_exact

        log_ratio = torch.log(distance.float() / max_exact) / math.log(max_distance / max_exact)
        large_bucket = max_exact + (log_ratio * (num_buckets - max_exact)).long()
        large_bucket = torch.min(large_bucket, torch.full_like(large_bucket, num_buckets - 1))

        return offset + torch.where(is_small, distance, large_bucket)

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, i, j):
        """Return the scaled bias of shape (heads, i, j) for the last ``i`` queries over ``j`` keys."""
        device = self.device
        query_positions = torch.arange(j - i, j, dtype=torch.long, device=device)
        key_positions = torch.arange(j, dtype=torch.long, device=device)
        relative = einx.subtract('j, i -> i j', key_positions, query_positions)
        buckets = self._relative_position_bucket(
            relative,
            causal=self.causal,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )
        per_head = self.relative_attention_bias(buckets)
        return rearrange(per_head, 'i j h -> h i j') * self.scale


class AlibiPositionalBias(nn.Module):
    def __init__(self, heads, total_heads=None, slopes: list[int] | None = None):
        super().__init__()
        self.heads = heads
        self.total_heads = default(total_heads, heads)

        slopes = torch.Tensor(default(slopes, self._get_slopes(heads)))
        slopes = rearrange(slopes, 'h -> h 1 1')

        self.register_buffer('slopes', slopes, persistent=False)
        self.register_buffer('bias', None, persistent=False)

    @property
    def device(self):
        return next(self.buffers()).device

    @staticmethod
    def _get_slopes(heads):
        def get_slopes_power_of_2(n):
            start = (2 ** (-2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio ** i for i in range(n)]

        if math.log2(heads).is_integer():
            return get_slopes_power_of_2(heads)

        closest_power_of_2 = 2 ** math.floor(math.log2(heads))
        return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads - closest_power_of_2]

    def forward_custom_pos(self, pos_i: torch.Tensor, pos_j: torch.Tensor | None = None):
        h, device = self.total_heads, self.device

        pos_j = default(pos_j, pos_i)
        bias = -einx.subtract('... j, ... i -> ... i j', pos_j, pos_i).abs()

        if bias.ndim == 3:
            bias = rearrange(bias, 'b i j -> b 1 i j')

        bias = bias * self.slopes
        num_heads_unalibied = h - bias.shape[-3]
        bias = pad_at_dim(bias, (0, num_heads_unalibied), dim=-3)

        return bias

    def forward(self, i, j):
        h, device = self.total_heads, self.device

        if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
            return self.bias[..., -i:, -j:]

        seq_arange = torch.arange(j - i, j, device=device)
        context_arange = torch.arange(j, device=device)
        bias = -einx.subtract('j, i -> 1 i j', context_arange, seq_arange).abs()

        bias = bias * self.slopes
        num_heads_unalibied = h - bias.shape[-3]
        bias = pad_at_dim(bias, (0, num_heads_unalibied), dim=-3)

        self.register_buffer('bias', bias, persistent=False)
        return self.bias


class DynamicPositionBias(nn.Module):
    """Relative position bias produced by a small MLP over the (signed) distance.

    Args:
        dim: hidden width of the MLP.
        heads: number of attention heads (output width of the MLP).
        depth: number of hidden blocks, at least 1.
        log_distance: feed ``sign(d) * log(|d| + 1)`` instead of the raw distance.
        norm: insert a ``LayerNorm`` after each hidden linear layer.
    """

    def __init__(self, dim, *, heads, depth, log_distance=False, norm=False):
        super().__init__()
        assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
        self.log_distance = log_distance

        def _norm_layer():
            # nn.Sequential accepts a None entry at construction but fails when it is
            # called in forward, so a disabled norm is represented by nn.Identity. This
            # keeps the sub-module indices (and therefore state_dict keys) unchanged.
            return nn.LayerNorm(dim) if norm else nn.Identity()

        self.mlp = nn.ModuleList([])

        self.mlp.append(nn.Sequential(nn.Linear(1, dim), _norm_layer(), nn.SiLU()))

        for _ in range(depth - 1):
            self.mlp.append(nn.Sequential(nn.Linear(dim, dim), _norm_layer(), nn.SiLU()))

        self.mlp.append(nn.Linear(dim, heads))

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, i, j):
        """Return the bias of shape (heads, n, n) for a square attention map (``i == j == n``)."""
        assert i == j
        n, device = j, self.device

        # get the (n x n) matrix of distances
        seq_arange = torch.arange(n, device=device)
        context_arange = torch.arange(n, device=device)
        indices = einx.subtract('i, j -> i j', seq_arange, context_arange)
        # shift so that the most negative distance maps to row 0 of the MLP output
        indices += (n - 1)

        # input to continuous positions MLP
        # NOTE(review): positions are cast to bfloat16, which presumably matches a model
        # running in bfloat16 — confirm; float32 MLP weights would not accept this input.
        pos = torch.arange(-n + 1, n, device=device).bfloat16()
        pos = rearrange(pos, '... -> ... 1')

        if self.log_distance:
            pos = torch.sign(pos) * torch.log(pos.abs() + 1)  # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1)

        for layer in self.mlp:
            pos = layer(pos)

        # get position biases
        bias = pos[indices]
        bias = rearrange(bias, 'i j h -> h i j')
        return bias


# designed for causal
class CoPE(nn.Module):
    """
    Appendix B of https://arxiv.org/abs/2405.18719

    Computes a position term from the attention logits themselves: sigmoid gates of
    the logits are accumulated into fractional positions, which then select (or
    interpolate between) query/position-embedding logits.
    """

    def __init__(
            self,
            dim,
            heads,
            max_pos,
            soft_onehot=False,
            talking_heads=False,
            soft_onehot_temp=5e-2
    ):
        super().__init__()
        self.max_pos = max_pos
        # One embedding per integer position, initialised to zeros.
        self.pos_emb = nn.Parameter(torch.zeros(max_pos, dim))

        # Optional 1x1 convolution mixing the logits across heads.
        self.talking_heads = nn.Conv2d(heads, heads, 1, bias=False) if talking_heads else None
        self.soft_onehot = soft_onehot
        self.soft_onehot_temp = soft_onehot_temp

        # The integer position buffer is only needed (and only registered) for the soft one-hot variant.
        if not soft_onehot:
            return

        self.register_buffer('positions', torch.arange(max_pos))

    def forward(self, query, attn_logits):
        # Returns the position term computed from `query` and `attn_logits`.
        # assumes query is (b, h, n, d) and attn_logits is (b, h, i, j), per the einsum patterns below — TODO confirm

        if exists(self.talking_heads):
            i, j = attn_logits.shape[-2:]
            # Mask of the entries above the causal diagonal.
            causal_mask = attn_logits.new_ones(i, j).triu_(j - i + 1).bool()

            attn_logits = self.talking_heads(attn_logits)

            # Re-apply the causal mask after mixing the heads.
            attn_logits = attn_logits.masked_fill(causal_mask, -torch.finfo(attn_logits.dtype).max)

        # compute positions

        gates = attn_logits.sigmoid()

        # Reverse cumulative sum over the last dimension: each entry accumulates the
        # gates from its own column to the last one; capped at the last embedding index.
        pos = gates.flip(-1).cumsum(dim=-1).flip(-1)
        pos = pos.clamp(max=self.max_pos - 1)

        # Logit of every query against every integer position embedding.
        logits_int = einsum('b h n d, p d -> b h n p', query, self.pos_emb)

        if self.soft_onehot:
            # NOTE(review): `pos` has the shape of attn_logits while the pattern below names
            # only two axes — confirm this einx call handles the leading dimensions as intended.
            diff_pos = einx.subtract('i, j -> i j', pos, self.positions).abs()
            soft_onehot_pos = F.softmax(-diff_pos / self.soft_onehot_temp, dim=-1)
            cope_pos_emb = einsum('b h i j p, b h i p -> b h i j', soft_onehot_pos, logits_int)
        else:
            # interpolate from integer positions
            pos_ceil = pos.ceil().long()
            pos_floor = pos.floor().long()
            logits_ceil = logits_int.gather(-1, pos_ceil)
            logits_floor = logits_int.gather(-1, pos_floor)

            # Fractional part of the position is the weight of the upper neighbour.
            w = pos - pos_floor
            cope_pos_emb = logits_ceil * w + logits_floor * (1 - w)

        return cope_pos_emb


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)

    ``x`` must be 4-dimensional; it is returned unchanged when ``n_rep == 1``.
    """
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    # Insert a repeat axis after the head axis, expand it, then fold it into the heads.
    expanded = x.unsqueeze(3).expand(bs, slen, n_kv_heads, n_rep, head_dim)
    return expanded.reshape(bs, slen, n_kv_heads * n_rep, head_dim)


class Attention(nn.Module):
    """Multi-head attention with rotary embeddings and grouped KV heads (no KV cache)."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_kv_heads if args.n_kv_heads is not None else args.n_heads
        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        self.alibi_heads = self.n_local_kv_heads
        # How many query heads share one KV head.
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.dim = args.dim

        q_width = args.n_heads * self.head_dim
        kv_width = self.n_kv_heads * self.head_dim
        self.wq = nn.Linear(args.dim, q_width, bias=False)
        self.wk = nn.Linear(args.dim, kv_width, bias=False)
        self.wv = nn.Linear(args.dim, kv_width, bias=False)
        self.wo = nn.Linear(q_width, args.dim, bias=False)

        assert self.alibi_heads <= self.n_local_heads, 'number of ALiBi heads must be less than the total number of heads'
        # Alternative relative position biases (disabled):
        # self.rel_pos = AlibiPositionalBias(heads=self.alibi_heads, total_heads=self.n_local_heads)
        # self.rel_pos = DynamicPositionBias(dim=self.dim // 4, heads=self.n_local_heads, log_distance=False, depth=2, norm=True)
        # self.rel_pos = RelativePositionBias(scale=self.head_dim ** 0.5, causal=False, heads=self.n_local_heads, num_buckets=32, max_distance=128)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """Attend over ``x`` of shape (batch, seq, dim) and return a tensor of the same shape.

        ``mask``, when given, is added to the attention scores after gaining a
        head dimension at position 1.
        """
        bsz, seqlen, _ = x.shape
        queries = self.wq(x).view(bsz, seqlen, self.n_local_heads, self.head_dim)
        keys = self.wk(x).view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        values = self.wv(x).view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        queries, keys = apply_rotary_emb(queries, keys, freqs_cis=freqs_cis)

        # repeat k/v heads if n_kv_heads < n_heads (GQA)
        keys = repeat_kv(keys, self.n_rep)
        values = repeat_kv(values, self.n_rep)

        # make heads be a batch dim: (bs, n_local_heads, seqlen, head_dim)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        scores = torch.matmul(queries, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        # scores = scores + self.rel_pos(xq.shape[-2], xk.shape[-2]).to(scores)

        if mask is not None:
            scores = scores + mask.unsqueeze(1)
        # Softmax in float32, then back to the dtype of the queries.
        weights = F.softmax(scores.float(), dim=-1).type_as(queries)
        attended = torch.matmul(weights, values)

        # concatenate all the heads, then apply the output projection
        merged = attended.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(merged)


class FeedForward(nn.Module):
    """Gated feed-forward block: ``w2(silu(w1(x)) * w3(x))``.

    The hidden width is 2/3 of ``hidden_dim``, optionally multiplied by
    ``ffn_dim_multiplier``, then rounded up to a multiple of ``multiple_of``.
    """

    def __init__(self, dim: int, hidden_dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float]):
        super().__init__()
        width = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            width = int(ffn_dim_multiplier * width)
        # round up to the next multiple of ``multiple_of``
        n_chunks = (width + multiple_of - 1) // multiple_of
        width = multiple_of * n_chunks

        self.w1 = nn.Linear(dim, width, bias=False)
        self.w2 = nn.Linear(width, dim, bias=False)
        self.w3 = nn.Linear(dim, width, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)


class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention and feed-forward, each with a residual connection."""

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        # Gating of the attention output (disabled):
        # self.gate = nn.Sequential(
        #     nn.Linear(args.dim, args.dim, bias=False),
        #     nn.SiLU(),
        #     nn.Linear(args.dim, 1, bias=False),
        #     nn.Sigmoid()
        # )
        self.layer_id = layer_id
        norm_eps = 1e-05
        self.attention_norm = RMSNorm(args.dim, eps=norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=norm_eps)
        # Created although the gate itself is disabled, so its weight stays part of the module.
        self.gate_norm = RMSNorm(args.dim, eps=norm_eps)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """Apply the layer to ``x`` and return a tensor of the same shape."""
        attended = self.attention(self.attention_norm(x), freqs_cis, mask=mask)
        # gate_out = self.gate(self.gate_norm(attention_out))
        residual = x + attended  # * gate_out
        return residual + self.feed_forward(self.ffn_norm(residual))  # , gate_out.squeeze(-1)


class Transformer(nn.Module):
    """Decoder stack whose layers are visited under per-sequence routing.

    Each sequence in the batch carries its own "current layer" index. On every
    routing step the sequences sitting on the same layer are run through that
    layer together, and their index then moves by the value returned from
    ``is_forward_at_position_for_depth`` (or by ``+1`` when no ``depths`` are
    supplied). A sequence is finished once its index equals ``params.n_layers``.
    Learned gating of the route is currently disabled.
    """

    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        # Optional learned "meta" embeddings that are prepended to every sequence.
        self.meta_tokens = params.meta_embeddings
        if self.meta_tokens:
            self.meta_token_embeddings = torch.nn.Parameter(torch.randn(params.meta_embeddings, params.dim))

        nn.init.normal_(self.tok_embeddings.weight, mean=0.0, std=1.0 / params.dim ** 0.5)
        # Additive positional embeddings are switched off; alternatives that were tried:
        # self.positional_embeddings = ScaledSinusoidalEmbedding(params.dim)
        # self.positional_embeddings = AbsolutePositionalEmbedding(params.dim, params.max_seq_len, l2norm_embed=False)
        self.positional_embeddings = always(0)

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=1e-05)
        # The output head falls back to the vocabulary size when no output_size is configured.
        self.output = nn.Linear(params.dim, params.output_size if params.output_size else params.vocab_size, bias=False)

        self.freqs_cis = precompute_freqs_cis(params.dim // params.n_heads, params.max_seq_len * 2, 500000.0, True)

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, labels=None, depths=None, global_attention_mask: Optional[torch.Tensor] = None):
        """Inference entry point; runs ``hierarchical_forward`` under ``torch.inference_mode``."""
        return self.hierarchical_forward(tokens, initial_labels=labels, depths=depths, global_attention_mask=global_attention_mask, is_train=False)

    def forward_train(self, tokens: torch.Tensor, labels: torch.Tensor, depths, global_attention_mask: Optional[torch.Tensor] = None):
        """Training entry point; same computation as ``forward`` with autograd enabled."""
        return self.hierarchical_forward(tokens, initial_labels=labels, depths=depths, global_attention_mask=global_attention_mask, is_train=True)

    def hierarchical_forward(self, tokens: torch.Tensor, initial_labels: torch.Tensor = None, depths: torch.Tensor = None, global_attention_mask: Optional[torch.Tensor] = None, is_train: bool = None):
        """Embed ``tokens``, route every sequence through the layers and project to logits.

        Args:
            tokens: 2-D tensor of token ids, one row per sequence.
            initial_labels: optional labels aligned with ``tokens``. When given
                they are left-padded with ``-100`` for the meta tokens; when
                ``None`` the returned labels are ``None`` as well.
            depths: optional per-sequence depth values passed to
                ``is_forward_at_position_for_depth``. When ``None`` every
                sequence simply moves one layer forward per step.
            global_attention_mask: optional per-token flags merged into the
                causal mask through ``merge_masks`` (this requires labels).
            is_train: accepted for the two entry points; not read at the moment.

        Returns:
            ``(logits, labels, all_next_layer_probs, history)`` where ``logits``
            is the float32 output of the head, ``all_next_layer_probs`` holds
            one (currently empty) list per sequence and ``history`` is the list
            of layer-index tensors recorded at the start of each routing step.

        Raises:
            ValueError: if ``global_attention_mask`` is given without labels.
        """
        _bsz, seqlen = tokens.shape
        seqlen += self.meta_tokens

        # Labels are optional (``forward`` defaults them to None), so only build
        # the padded label tensor when there is something to pad.
        if initial_labels is not None:
            ignore_labels = torch.full((_bsz, self.meta_tokens), -100, dtype=torch.int, device=initial_labels.device)
            labels = torch.cat([ignore_labels, initial_labels], dim=1)
        else:
            labels = None

        e = self.tok_embeddings(tokens)
        h = torch.cat([self.meta_token_embeddings.unsqueeze(0).expand(_bsz, -1, -1), e], dim=1) if self.meta_tokens else e
        h = h + self.positional_embeddings(h)

        # NOTE(review): the clone looks like a way to avoid reusing a tensor that
        # was created under inference_mode in a later training pass - confirm
        # before simplifying.
        self.freqs_cis = self.freqs_cis.clone().to(h.device)
        freqs_cis = self.freqs_cis[:seqlen]

        # Additive causal mask: 0 on and below the diagonal, -inf above it.
        mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
        mask = torch.triu(mask, diagonal=1)
        mask = mask.type_as(h)

        # Sliding-window variant that was tried:
        # mask, w = torch.ones(seqlen, seqlen, device=tokens.device), 31 # needs to be odd
        # mask = torch.triu(mask, diagonal=-w // 2) - torch.triu(mask, diagonal=w // 2 + 1)
        # mask = torch.tril(mask)
        # mask[mask == 0] = float("-inf")

        if global_attention_mask is not None:
            if labels is None:
                raise ValueError("global_attention_mask requires labels: merge_masks uses them to locate the first label position")
            # Meta tokens are always flagged as global.
            meta_mask = torch.ones(_bsz, self.meta_tokens, dtype=torch.bool, device=global_attention_mask.device)
            mask = merge_masks(mask, torch.cat([meta_mask, global_attention_mask], dim=1), labels)
        else:
            mask = mask.unsqueeze(0).repeat(_bsz, 1, 1)

        current_layer_indexes = torch.zeros(_bsz, device=tokens.device).type(dtype=torch.long)
        invocations = torch.zeros(_bsz, device=tokens.device).type(dtype=torch.long)
        history = []
        all_next_layer_probs = [[] for _ in range(tokens.shape[0])]

        while True:
            history.append(current_layer_indexes.clone().detach())
            active_sequences_mask = (current_layer_indexes != self.params.n_layers)

            if not active_sequences_mask.any():
                break

            unique_layer_indices = torch.unique(current_layer_indexes[active_sequences_mask])

            # NOTE(review): layer indices are updated inside this loop, so a
            # sequence that moves onto a layer still pending in
            # `unique_layer_indices` is run again within the same routing step.
            # Harmless while all sequences move in lockstep - confirm for mixed routes.
            for layer_index in unique_layer_indices:
                i = layer_index.item()
                this_batch = (current_layer_indexes == i) & active_sequences_mask
                batch_h = h[this_batch]
                batch_mask = mask[this_batch]
                layer = self.layers[i]
                h[this_batch] = layer(batch_h, freqs_cis, mask=batch_mask)

                # Learned gating is disabled. It compared the layer's gate
                # probability with a threshold of 0.7 - 0.05 * invocations and
                # stored the probabilities in all_next_layer_probs.

                if depths is not None:
                    pos = invocations[this_batch]
                    depth = depths[this_batch]
                    increment_decrement = torch.tensor([is_forward_at_position_for_depth(p.item(), d.item()) for p, d in zip(pos, depth)], device=tokens.device)
                else:
                    # No routing information: plain forward pass, one layer per step.
                    increment_decrement = 1

                current_layer_indexes[this_batch] += increment_decrement
                current_layer_indexes = current_layer_indexes.clamp(min=0)

                invocations[this_batch] += 1

        h = self.norm(h)
        output = self.output(h).float()
        return output, labels, all_next_layer_probs, history


def merge_masks(local_mask: torch.Tensor, global_attention_mask: torch.Tensor, labels: torch.Tensor):
    """Overlay per-sequence global-attention flags on a shared local mask.

    Args:
        local_mask: 2-D mask shared by the whole batch; it is not modified.
        global_attention_mask: 2-D (batch, sequence) flags marking global tokens.
        labels: 2-D labels; ``-100`` marks positions without a label.

    Returns:
        A new 3-D tensor holding one copy of ``local_mask`` per sequence, with
        the value ``1`` written wherever either token of a pair is flagged
        global and the column is not past the sequence's first labelled position.

    NOTE(review): the overlay writes ``1`` rather than ``0``; if the attention
    module adds this mask to the scores that is a +1 bias, not a plain
    "allowed" marker - confirm against ``Attention``.
    """
    n_seqs, n_tokens = global_attention_mask.size()

    # A pair (row, col) is global when either of its tokens is flagged.
    pair_is_global = global_attention_mask.unsqueeze(2) + global_attention_mask.unsqueeze(1)

    # Columns strictly after the first labelled position lose their global flag.
    first_label_pos = (labels != -100).long().argmax(dim=1)
    token_pos = torch.arange(n_tokens, device=labels.device)
    past_first_label = token_pos.unsqueeze(0) > first_label_pos.unsqueeze(1)
    pair_is_global[past_first_label[:, None, :].expand(-1, n_tokens, -1)] = 0

    merged = local_mask.unsqueeze(0).expand(n_seqs, -1, -1).clone()
    merged[pair_is_global > 0.1] = 1
    return merged


def is_forward_at_position_for_depth(pos, depth):
    """Return the layer-index step (+1 forward, -1 backward) for routing step ``pos``.

    The schedule is currently a fixed run of forward steps, so ``depth`` is not
    consulted; positions beyond the schedule reuse its last entry. The
    depth-dependent schedule that was tried is kept below for reference.
    """
    schedule = (1, 1, 1)  # + [-1, -1, 1, 1] * max(0, depth - 1) + [1]
    last_index = len(schedule) - 1
    index = pos if pos < last_index else last_index
    return schedule[index]


In [8]:
import torch
from torch import nn
from dataclasses import dataclass
from dataclasses_json import dataclass_json
from typing import Optional, Callable
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


def create_model(root_path, args_class, model_class, state_dict_location=None):
    conf_path = root_path + "/params.json"
    with open(conf_path, 'r') as file:
        conf = args_class.from_json(file.read())
        print(conf)

    with torch.device('cpu'):
        llm_model = model_class(params=conf)
        if state_dict_location is not None:
            state_dict = torch.load(state_dict_location, map_location=torch.device("cpu"), weights_only=True)
            res = llm_model.load_state_dict(state_dict, strict=False, assign=True)
            print(res)

    return llm_model


@dataclass_json
@dataclass
class TrainConf:
    """Settings and collaborators consumed by ``train``.

    The hand-written ``__init__`` accepts arbitrary keyword arguments and
    silently drops those that do not name a declared field; fields that are
    not passed keep their class-level default.
    """

    # Number of passes over the training set.
    epochs: int = 1
    # The optimizer steps once every this many batches.
    accumulation_batches: int = 1
    optimizer: torch.optim.Optimizer = None
    scheduler: torch.optim.lr_scheduler.LRScheduler = None
    # Passed through to the model-call function, which returns the loss tensor.
    loss_fn: Callable[..., torch.Tensor] = None
    # Called as ds_loader(dataset, dataset_length, epoch).
    ds_loader: Callable[[list, int, int], DataLoader] = None
    # Optional hook called as eval_model(model, epoch) after each epoch.
    eval_model: Optional[Callable[[nn.Module, int], None]] = None

    def __init__(self, **kwargs):
        for field_name, value in kwargs.items():
            if not hasattr(self, field_name):
                continue  # unknown keys are ignored
            setattr(self, field_name, value)


def get_total_grad_norm(llm_model):
    """Return the L2 norm of all gradients of ``llm_model`` taken together.

    Parameters without a gradient are skipped. The result is the square root
    of the summed squared per-parameter norms (``0`` when nothing has a gradient).
    """
    squared_sum = 0
    for _, param in llm_model.named_parameters():
        grad = param.grad
        if grad is not None:
            squared_sum += grad.data.norm(2).item() ** 2
    return squared_sum ** 0.5


def call_model(llm, batch, loss_fn):
    """Generic training step: move a batch to the GPU, run the model, return the loss.

    ``batch`` must provide ``"input_ids"`` and ``"labels"``. The model is called
    as ``llm.forward_train(input_ids)`` and the loss as ``loss_fn(labels, output)``.

    This is the function bound as the default ``model_call`` of ``train``. A later
    cell defines another ``call_model`` for the routed ``Transformer``, whose
    ``forward_train`` also needs labels and depths.
    """
    on_gpu = {key: batch[key].to("cuda") for key in ("input_ids", "labels")}
    prediction = llm.forward_train(on_gpu["input_ids"])
    return loss_fn(on_gpu["labels"], prediction)


def train(conf: TrainConf, llm_model: nn.Module, train_ds: list, validation_ds: list, model_call=call_model):
    """Train ``llm_model`` for ``conf.epochs`` epochs, validating after each one.

    Args:
        conf: training settings. ``optimizer``, ``loss_fn`` and ``ds_loader``
            are required; ``scheduler`` and ``eval_model`` may be ``None``.
        llm_model: the model to optimise.
        train_ds: training samples handed to ``conf.ds_loader``.
        validation_ds: validation samples handed to ``conf.ds_loader``.
        model_call: ``model_call(model, batch, loss_fn)`` returning the loss tensor.

    Returns:
        One list per epoch with the rounded loss of every training batch.
    """
    # NOTE(review): anomaly detection is a debugging aid and stays enabled
    # globally after this function returns - confirm it is still wanted.
    torch.autograd.set_detect_anomaly(True)
    training_loss_curves = []

    train_length = len(train_ds)
    validation_length = len(validation_ds)

    for epoch in range(conf.epochs):
        data_loader = conf.ds_loader(train_ds, train_length, epoch)
        validation_loader = conf.ds_loader(validation_ds, validation_length, epoch)
        print(f"\nEPOCH {epoch}: len(train)={len(data_loader)}, len(validation)={len(validation_loader)}")

        llm_model.train()
        # conf.optimizer.train()
        training_loss_curve = []
        gradient_norm = []

        for i, batch in enumerate(data_loader):
            loss = model_call(llm_model, batch, conf.loss_fn)
            loss.backward()

            # get_total_grad_norm already returns the norm (it takes the square
            # root itself), so it is logged as is - no second square root.
            gradient_norm.append(round(get_total_grad_norm(llm_model), 4))

            # NOTE(review): gradients of a trailing incomplete accumulation
            # group are not stepped here; they stay in .grad for the next epoch.
            if (i + 1) % conf.accumulation_batches == 0:
                conf.optimizer.step()
                conf.optimizer.zero_grad()

            # The batch is a dict in this notebook, so len(batch) counted its
            # keys, not its samples. Log the raw loss, like the validation
            # loop below, so both numbers are comparable.
            training_loss_curve.append(round(loss.item(), 4))

        # Empty loaders yield nan; they do not raise.
        if training_loss_curve:
            avg_tloss = round(sum(training_loss_curve) / len(training_loss_curve), 4)
        else:
            avg_tloss = float("nan")
        training_loss_curves.append(training_loss_curve)

        running_vloss = 0.0
        llm_model.eval()
        # conf.optimizer.eval()
        with torch.no_grad():
            for i, batch in enumerate(validation_loader):
                loss = model_call(llm_model, batch, conf.loss_fn)
                running_vloss += loss.item()
        n_validation_batches = len(validation_loader)
        avg_vloss = round(running_vloss / n_validation_batches, 4) if n_validation_batches else float("nan")

        # The scheduler is optional (TrainConf defaults it to None); without
        # one, report the optimizer's current learning rates.
        if conf.scheduler is not None:
            current_lr = conf.scheduler.get_last_lr()
        else:
            current_lr = [group["lr"] for group in conf.optimizer.param_groups]
        min_norm = min(gradient_norm, default=float("nan"))
        max_norm = max(gradient_norm, default=float("nan"))
        print(f"LOSS train {avg_tloss} valid {avg_vloss} lr {current_lr} "
              f"min gradient norm: {min_norm} max gradient norm: {max_norm}")

        if conf.eval_model is not None:
            conf.eval_model(llm_model, epoch)

        if conf.scheduler is not None:
            conf.scheduler.step()

    return training_loss_curves


def visualize(nested_losses):
    """Scatter-plot the training loss over time, coloured by epoch.

    Args:
        nested_losses: one sequence of batch losses per epoch (as returned by
            ``train``).

    The flattened series is thinned to roughly 1000 points by keeping every
    ``step``-th value, and the loss axis is logarithmic.
    """
    losses, epochs = [], []
    for epoch_number, epoch_losses in enumerate(nested_losses, start=1):
        losses.extend(epoch_losses)
        epochs.extend([epoch_number] * len(epoch_losses))

    # Keep every `step`-th point, starting with the step-th one.
    step = max(len(losses) // 1000, 1)
    thinned = slice(step - 1, None, step)
    shown_losses = losses[thinned]
    shown_epochs = epochs[thinned]

    time_axis = range(1, len(shown_losses) + 1)

    plt.scatter(time_axis, shown_losses, c=shown_epochs, cmap='viridis')
    plt.xlabel('Time')
    plt.ylabel('Loss')
    plt.yscale("log")
    plt.title('Loss curve')
    plt.colorbar(label='Epoch')
    plt.show()


In [None]:
# Task switch read by loss_fn, inference_model and call_model in this cell:
# False selects the cross-entropy branch, True the binary-classification branch.
for_classification = False
tokenizer_path = "/content/drive/MyDrive/ML/Llama3.2-3B-Instruct/tokenizer.model"
tokenizer = Tokenizer(tokenizer_path)

# Datasets are read from the mounted Google Drive.
train_ds = load("/content/drive/MyDrive/ML/predicate_logic/train/7k/prop_examples_lp.txt", for_classification=for_classification, tokenizer=tokenizer)
validation_ds = load("/content/drive/MyDrive/ML/predicate_logic/validation/prop_examples_lp.txt", for_classification=for_classification, tokenizer=tokenizer)
# Sets scored by inference_model after each epoch. Note that inference_ids_lp
# is loaded from the same file as validation_ds.
inference_ids_lp = load("/content/drive/MyDrive/ML/predicate_logic/validation/prop_examples_lp.txt", for_classification=for_classification, tokenizer=tokenizer)
inference_ids_lp_star = load("/content/drive/MyDrive/ML/predicate_logic/validation/prop_examples_lp_star.txt", for_classification=for_classification, tokenizer=tokenizer)
inference_ids_rp = load("/content/drive/MyDrive/ML/predicate_logic/validation/prop_examples_rp.txt", for_classification=for_classification, tokenizer=tokenizer)

def ds_loader(ds, ds_length, epoch):
    """Wrap ``ds`` in a shuffling ``DataLoader`` with batches of four samples.

    ``ds_length`` and ``epoch`` belong to the loader signature expected by
    ``TrainConf.ds_loader``; they were used by a curriculum sampler that is
    disabled and are ignored here:
    # curriculum = processor.train_curriculum(ds, epoch, select_layer_items=ds_length // 7, non_select_layer_items=ds_length // 140)
    """
    def collate(samples):
        # Batches are padded through pad_collate with the notebook-level `pad` value.
        return pad_collate(samples, padding=pad)

    return DataLoader(ds, batch_size=4, shuffle=True, collate_fn=collate)

def loss_fn(llm_output, depth):
    """Compute the training loss from the 4-tuple returned by the model.

    Args:
        llm_output: ``(logits, labels, gating, history)``; only the logits and
            labels are used.
        depth: per-sequence depths; only the disabled routing loss needed them.

    Returns:
        Cross-entropy over the logits in the default mode, or binary
        cross-entropy on the logits of the last position when
        ``for_classification`` is set.

    Disabled experiments (all based on ``gating``): a loss rewarding every
    traversed layer, a small confidence penalty pushing gate probabilities
    away from 0.5, and a routing loss supervising each gate with
    ``is_forward_at_position_for_depth``.
    """
    logits, labels, _gating, _history = llm_output

    if not for_classification:
        # cross_entropy expects the class dimension second
        # (assumes logits is (batch, seq, classes) - TODO confirm).
        return F.cross_entropy(logits.transpose(1, 2), labels, label_smoothing=0.)

    # NOTE(review): `labels` is whatever the model returned; verify that it
    # matches the last-position logits in classification mode.
    last_position_logits = logits[:, -1]
    return F.binary_cross_entropy_with_logits(last_position_logits, labels)

def inference_model(llm_model, epoch=0):
    """Evaluate ``llm_model`` on the three inference sets and plot the routes taken.

    Args:
        llm_model: model called as ``llm_model(xs, labels=..., depths=...)``,
            returning ``(logits, labels, gates, history)``.
        epoch: accepted because ``train`` calls the hook as
            ``eval_model(model, epoch)``; not used.
    """
    if for_classification:
        vocabulary, answer_position = {1: 1, 0: 0}, 1
    else:
        vocabulary, answer_position = special_tokens, 0

    with torch.no_grad():
        # One entry per evaluated sample: the layer index recorded at every routing step.
        routes = []

        def model_invocation(xs, labels, depth, **kwargs):
            # Labels and depth are wrapped into a batch of one.
            labels = torch.tensor([labels]).to("cuda")
            depths = torch.tensor([depth]).to("cuda")
            logits, labels, gates, history = llm_model(xs, labels=labels, depths=depths)
            routes.append([x.item() for x in history])
            return logits

        eval_model(model_invocation, inference_ids_lp, vocabulary=vocabulary, answer_position=answer_position)
        # Routes are only recorded while eval_model runs, so they are plotted
        # after the first evaluation; previously this ran before any evaluation
        # and therefore always received an empty list.
        visualize_routes(routes)
        eval_model(model_invocation, inference_ids_lp_star, vocabulary=vocabulary, answer_position=answer_position)
        eval_model(model_invocation, inference_ids_rp, vocabulary=vocabulary, answer_position=answer_position)

def call_model(llm, batch, loss_fn):
    """Training step for the routed ``Transformer``: forward pass plus loss.

    ``batch`` must provide ``"input_ids"``, ``"labels"`` and ``"depth"``; all
    three are moved to the GPU. In classification mode the model receives no
    labels. Replaces the generic ``call_model`` defined with the training harness.
    """
    device = "cuda"
    input_ids = batch["input_ids"].to(device)
    labels = batch["labels"].to(device)
    depths = batch["depth"].to(device)

    model_labels = labels if not for_classification else None
    llm_output = llm.forward_train(input_ids, model_labels, depths=depths)
    return loss_fn(llm_output, depths)

# Build the model from <root_path>/params.json and load the checkpoint weights
# (create_model loads non-strictly and prints the load report).
root_path = "/content/drive/MyDrive/ML/predicate_logic/llama_predicate_logic_hierarchical_routing"
state_dict_location = "/content/drive/MyDrive/ML/Llama3.2-3B-Instruct/consolidated.00.pth"
llm_model = create_model(root_path, ModelArgs, Transformer, state_dict_location=state_dict_location)
# Make every parameter trainable, cast it to bfloat16 and move it to the GPU.
for name, parameter in llm_model.named_parameters():
    parameter.requires_grad = True
    parameter.data = parameter.data.bfloat16().to("cuda")
    print(f"{name} (device={parameter.device}, requires_grad={parameter.requires_grad}, dtype={parameter.dtype})")

# optimizer = AdamWScheduleFree(llm_model.parameters(), weight_decay=0.05, betas=(0.9, 0.98), lr=1e-4, warmup_steps=2000)
optimizer = torch.optim.AdamW(llm_model.parameters(), weight_decay=0.1, betas=(0.9, 0.999), lr=1e-5)
# train() steps the scheduler once per epoch, so the learning rate is multiplied by 0.9 every epoch.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
train_conf = TrainConf(epochs=5, optimizer=optimizer, scheduler=scheduler, loss_fn=loss_fn, ds_loader=ds_loader, eval_model=inference_model)

# Train with the routed call_model defined in this cell and plot the per-batch loss curves.
visualize(train(train_conf, llm_model, train_ds, validation_ds, model_call=call_model))

ModelArgs(dim=3072, n_layers=28, n_heads=24, n_kv_heads=8, vocab_size=128256, output_size=128256, multiple_of=256, ffn_dim_multiplier=1.0, meta_embeddings=0, max_batch_size=32, max_seq_len=1024)
