In [58]:
from datasets import load_from_disk, load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
import numpy as np
from transformers import DataCollatorForLanguageModeling, TrainingArguments, Trainer, AutoModelForMaskedLM, BertConfig, BertModel, BertForPreTraining, BertForMaskedLM, AutoModel, PretrainedConfig, AutoConfig, AutoModelForSequenceClassification

import sys 
import os
sys.path.append(os.path.abspath(".."))

from collate import DataCollatorForLanguageModelingSpan


# Load the on-disk dataset and drop the two columns that are not needed here.
dataset = load_from_disk("../batch128").remove_columns(
    ["species_name", "__index_level_0__"]
)

# Tokenizer comes from the SpeciesLM repo, pinned to its downstream revision.
tokenizer = AutoTokenizer.from_pretrained(
    "gagneurlab/SpeciesLM",
    revision="downstream_species_lm",
)

# NOTE: `from_pretrained` DOES load the checkpoint weights. They are replaced
# with random values further down by `randomize_model`. The config-only route
# (which never loads weights) is kept here for reference:
# https://stackoverflow.com/questions/65072694/make-sure-bert-model-does-not-load-pretrained-weights
#   config = PretrainedConfig.from_pretrained("togethercomputer/m2-bert-80M-2k")
#   model = BertForMaskedLM(config)
#
# The model needs a masked-LM head: the training cell below feeds it batches
# produced by an MLM data collator (mlm=True), so a sequence-classification
# head is the wrong architecture for this notebook.
#
# TODO(review): pin `revision=` to a specific commit. With
# trust_remote_code=True the library itself warns that unpinned remote code
# may change between runs.
model = AutoModelForMaskedLM.from_pretrained(
    "togethercomputer/m2-bert-80M-2k",
    trust_remote_code=True,
)


                

Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.
You are using a model of type m2_bert to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.


-- Bidirectional: True
-- Using Long Conv Residual: True
-- Hyena w: 10
-- Hyena w mod: 1
-- Hyena filter order: 128
-- Hyena filter dropout: 0.2
-- Hyena filter wd: 0.1
-- Hyena filter emb dim: 5
-- Hyena filter lr: 0.001
-- Hyena filter lr pos emb: 1e-05


Some weights of the model checkpoint at togethercomputer/m2-bert-80M-2k were not used when initializing BertForMaskedLM: ['model.bert.encoder.layer.4.attention.filter_fn2.implicit_filter.4.weight', 'model.bert.encoder.layer.4.attention.filter_fn.pos_emb.t', 'model.bert.encoder.layer.0.attention.filter_fn2.implicit_filter.2.weight', 'model.bert.encoder.layer.4.mlp.layernorm.weight', 'model.bert.encoder.layer.1.attention.filter_fn2.implicit_filter_rev.4.bias', 'model.bert.encoder.layer.10.attention.filter_fn.implicit_filter.5.freq', 'model.bert.encoder.layer.9.attention.filter_fn.pos_emb.z', 'model.bert.encoder.layer.6.attention.filter_fn2.implicit_filter_rev.4.weight', 'model.bert.encoder.layer.11.attention.filter_fn.implicit_filter.5.freq', 'model.bert.embeddings.token_type_embeddings.weight', 'model.bert.encoder.layer.6.attention.filter_fn.implicit_filter.1.freq', 'model.bert.encoder.layer.4.attention.filter_fn2.implicit_filter.5.freq', 'model.bert.encoder.layer.1.attention.filter_fn2

In [54]:
# Show the loaded model's module tree (the cell's last expression is displayed).
model

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30528, 768, padding_idx=0)
      (position_embeddings): Embedding(2048, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): MonarchMixerSequenceMixing(
            (filter_fn): HyenaFilter(
              (dropout): Dropout(p=0.2, inplace=False)
              (pos_emb): PositionalEmbedding()
              (implicit_filter): Sequential(
                (0): Linear(in_features=5, out_features=128, bias=True)
                (1): Sin()
                (2): Linear(in_features=128, out_features=128, bias=True)
                (3): Sin()
                (4): Linear(in_features=128, out_features=128, bias=True)
                (5): Sin()
                (6): Line

## Trying to reinitialize the model weights

In [55]:
# Snapshot two weight slices from the first encoder layer's sequence mixer so
# they can be compared with the values printed after re-initialization.
layer0_mixer = model.bert.encoder.layer[0].attention
print(layer0_mixer.filter_fn.implicit_filter[0].weight[0])
print(layer0_mixer.short_filter.weight[0])

tensor([ 0.0056, -0.0145, -0.0089, -0.0047, -0.0043],
       grad_fn=<SelectBackward0>)
tensor([[-0.0164,  0.0040,  0.0007]], grad_fn=<SelectBackward0>)


In [None]:
# Grab the iterator of (name, module) pairs that `named_modules()` yields.
# NOTE(review): `modules` is not used by any cell visible here and this cell
# has no execution count — confirm whether it is still needed.
modules = model.named_modules()

In [56]:
import torch

# Adapted from a StackOverflow answer; extended to also randomize Conv1d layers.
def randomize_model(model):
    """Re-initialize the standard layers of ``model`` in place.

    Every submodule is visited once and reset according to its type:

    * ``Linear`` / ``Embedding`` / ``Conv1d``: weights are redrawn from a
      normal distribution with mean 0 and std ``model.config.initializer_range``.
    * ``Embedding`` with a ``padding_idx``: that row is zeroed again after the
      draw, so the padding vector stays all-zero as it is after construction.
    * ``Linear`` / ``Conv1d``: the bias is zeroed when the layer has one.
    * ``LayerNorm``: weight is set to 1 and bias to 0, each only when present
      (a LayerNorm built without affine parameters has neither).

    Parameters held by any other module type are left untouched.
    NOTE(review): confirm whether custom modules in the model own parameters
    that should be reset as well.

    Args:
        model: a ``torch.nn.Module`` exposing ``config.initializer_range``.

    Returns:
        The same ``model`` object, modified in place.
    """
    init_std = model.config.initializer_range

    # no_grad lets us overwrite leaf parameters in place without autograd
    # tracking the operation.
    with torch.no_grad():
        for module in model.modules():
            if isinstance(module, (torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv1d)):
                module.weight.normal_(mean=0.0, std=init_std)
                if isinstance(module, torch.nn.Embedding):
                    # The normal draw above also overwrote the padding row.
                    if module.padding_idx is not None:
                        module.weight[module.padding_idx].zero_()
                elif module.bias is not None:
                    module.bias.zero_()
            elif isinstance(module, torch.nn.LayerNorm):
                # Both are None when the layer has no affine parameters.
                if module.weight is not None:
                    module.weight.fill_(1.0)
                if module.bias is not None:
                    module.bias.zero_()
    return model

# Re-initialize the model in place, then print the same two weight slices as
# before so the change is visible.
randomize_model(model)

layer0_mixer = model.bert.encoder.layer[0].attention
print(layer0_mixer.filter_fn.implicit_filter[0].weight[0])
print(layer0_mixer.short_filter.weight[0])

tensor([ 0.0075, -0.0093,  0.0100, -0.0018,  0.0285],
       grad_fn=<SelectBackward0>)
tensor([[ 0.0049, -0.0237,  0.0216]], grad_fn=<SelectBackward0>)


## Checking whether the loss still goes to zero

In [59]:
#Load model directly
%env WANDB_PROJECT=singlesamplednam2

dataset = load_from_disk("../microset")
dataset = dataset.remove_columns(["species_name", "__index_level_0__"])

data_collator = DataCollatorForLanguageModelingSpan(tokenizer, mlm=True, mlm_probability = 0.02, span_length = 6)

training_args = TrainingArguments(
    output_dir="./results/correct_model_1sample",
    
    max_steps=20000,
    
    seed=17,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=16,
    
    logging_strategy="steps",
    logging_steps=1,
    
    evaluation_strategy="no",
    
    #dataloader_num_workers=4,
    #dataloader_prefetch_factor=2,
    run_name="correct_model_1sample",
    report_to="none" #"wandb"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    data_collator=data_collator
)

trainer.train()

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


env: WANDB_PROJECT=singlesamplednam2


IndexError: too many indices for tensor of dimension 2

In [62]:
# Draw one shuffled training example (batch_size=1) to inspect what a batch
# looks like when it comes straight out of a plain DataLoader.
train_dataloader = DataLoader(dataset["train"], batch_size=1, shuffle=True)
sample = next(iter(train_dataloader))
# NOTE(review): in the recorded output `input_ids` is a list of one-element
# tensors rather than a single tensor — presumably because the dataset is not
# in torch format; confirm before relying on this layout.
sample

{'input_ids': [tensor([2]),
  tensor([4837]),
  tensor([1092]),
  tensor([260]),
  tensor([1025]),
  tensor([4085]),
  tensor([4037]),
  tensor([3847]),
  tensor([3086]),
  tensor([44]),
  tensor([161]),
  tensor([631]),
  tensor([2509]),
  tensor([1829]),
  tensor([3207]),
  tensor([526]),
  tensor([2090]),
  tensor([153]),
  tensor([598]),
  tensor([2379]),
  tensor([1309]),
  tensor([1128]),
  tensor([401]),
  tensor([1592]),
  tensor([2257]),
  tensor([822]),
  tensor([3276]),
  tensor([801]),
  tensor([3190]),
  tensor([458]),
  tensor([1820]),
  tensor([3170]),
  tensor([379]),
  tensor([1503]),
  tensor([1901]),
  tensor([3496]),
  tensor([1684]),
  tensor([2626]),
  tensor([2299]),
  tensor([992]),
  tensor([3955]),
  tensor([3518]),
  tensor([1772]),
  tensor([2977]),
  tensor([3701]),
  tensor([2501]),
  tensor([1800]),
  tensor([3091]),
  tensor([64]),
  tensor([241]),
  tensor([949]),
  tensor([3783]),
  tensor([2831]),
  tensor([3120]),
  tensor([179]),
  tensor([702]),
  