In [1]:
from models import *
from jepa import JEPA
from datasets import H5Dataset, SparseDataset, METADATA
from trainer import Trainer

In [2]:
# Dataset selection shared by the train and validation splits.
mode, days = 'cite', [2]


def _build_split(split_name, batch_size=16):
    """Create the H5Dataset for `split_name` and a dataloader over it.

    Returns a `(dataset, dataloader)` pair. Uses the notebook-level `mode`
    and `days` settings defined above.
    """
    dataset = H5Dataset(split_name, mode, days=days)
    return dataset, dataset.get_dataloader(batch_size=batch_size)


# Train split first, then validation -- same construction order as before.
train_dataset, train_dataloader = _build_split('train')
val_dataset, val_dataloader = _build_split('val')

In [3]:
# Width of the latent space shared by the encoder output, the predictor and
# the decoder input.
latent_dim = 1024
# NOTE(review): n_chan is not referenced anywhere in this cell; the literal
# 256 passed to the constructors below is presumably the same quantity --
# confirm against the Encoder/Decoder/LinearCoder signatures before wiring it in.
n_chan = 256

# Sub-models are instantiated in a fixed order: encoder, decoder, predictor.
in_encoder = Encoder(
    22050, latent_dim, 256, 6, 11, 'enformer', 'attention',
)
out_decoder = Decoder(
    latent_dim, 140, 256, 6, 11, 'enformer', 'conv',
)
predictor = LinearCoder(
    [latent_dim, latent_dim], 256, True, 0.05,
)

# JEPA receives its components as a name -> module mapping.
models = dict(
    in_encoder=in_encoder,
    out_decoder=out_decoder,
    predictor=predictor,
)
jepa = JEPA(models)
print(jepa)

Joint Embedding (Predictive) Architecture with the following models:
	in_encoder with 7917054 parameters
	out_decoder with 8374679 parameters
	predictor with 2955776 parameters



In [4]:
# Example pipeline for JEPA inference: pull a single batch from the training
# dataloader, run it through the model and compare the output shape with the
# target shape.
# Each batch unpacks as ((x, day), y). The previous version iterated with a
# `for ... break` loop and contained no-op self-assignments (`x, day = x, day`,
# `y = y`); fetching the first batch directly does the same work.
(x, day), y = next(iter(train_dataloader))

pred = jepa.infer(x, day)
print(pred.shape, y.shape)

# The prediction must line up with the target; fail loudly here rather than
# somewhere inside the training loop.
assert pred.shape == y.shape, (
    f"JEPA output shape {tuple(pred.shape)} does not match "
    f"target shape {tuple(y.shape)}"
)

torch.Size([16, 140]) torch.Size([16, 140])


In [5]:
# All three sub-models currently use identical optimisation settings, so each
# per-model dict is generated from one tuple of model names instead of being
# spelled out key by key. Key order matches the original literals.
model_names = ('in_encoder', 'out_decoder', 'predictor')

initial_lrs = dict.fromkeys(model_names, 0.04)
lr_decay_periods = dict.fromkeys(model_names, 5)
lr_decay_gammas = dict.fromkeys(model_names, 0.5)
weight_decays = dict.fromkeys(model_names, 0.0001)

# Arguments forwarded positionally to Trainer.train below.
# NOTE(review): patience / num_tries look like early-stopping controls --
# confirm their exact meaning in trainer.py.
num_epochs, eval_every = 20, 1
patience, num_tries = 2, 1

trainer = Trainer(
    jepa=jepa,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    initial_lrs=initial_lrs,
    lr_decay_gammas=lr_decay_gammas,
    lr_decay_periods=lr_decay_periods,
    weight_decays=weight_decays,
)

# `jepa` is rebound to whatever Trainer.train returns.
jepa = trainer.train(num_epochs, eval_every, patience, num_tries)

  0%|                                                   | 0/960 [00:00<?, ?it/s]

Using cpu for training
Adjusting learning rate of group 0 to 4.0000e-02.
Adjusting learning rate of group 0 to 4.0000e-02.
Adjusting learning rate of group 0 to 4.0000e-02.
Training the following ensemble for 20 epochs:

Joint Embedding (Predictive) Architecture with the following models:
	in_encoder with 7917054 parameters
	out_decoder with 8374679 parameters
	predictor with 2955776 parameters


-------------------------------------------------------------
------------------  TRAIN - EPOCH NUM 0  -------------------
-------------------------------------------------------------


  0%|                                         | 1/960 [00:23<6:08:51, 23.08s/it][E thread_pool.cpp:113] Exception in thread pool task: mutex lock failed: Invalid argument
  0%|                                         | 1/960 [00:29<7:57:42, 29.89s/it]
  0%|                                                   | 0/273 [00:00<?, ?it/s]


Catching keyboard interrupt!!!

-------------------------------------------------------------
-------------------  VAL - EPOCH NUM 0  -------------------
-------------------------------------------------------------


AttributeError: 'JEPA' object has no attribute '_infer'

In [None]:
# Show the first prediction next to its target for one training batch.
# Only the first batch is used; the loop exits immediately afterwards.
for inputs, y in train_dataloader:
    x, day = inputs
    out = jepa.infer(x, day)
    print(out[0], y[0], sep='\n')
    break