In [1]:
import os
import pandas as pd
import os, gc
import numpy as np
from sklearn.model_selection import KFold

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast

import neptune
from neptune.utils import stringify_unsupported
from tqdm import tqdm, notebook
import transformers
from collections import defaultdict
import glob

import sys
import argparse
from copy import copy
import importlib

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
!pip install graphviz
from torchviz import make_dot


Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
[notice] A new release of pip is available: 23.1.2 -> 24.0
[notice] To update, run: python -m pip install --upgrade pip


In [3]:
# Root folder that holds the project sub-folders (alternative: '../input/asl-fingerspelling-config').
BASEDIR = './'
# Append every project sub-folder to sys.path so its modules can be imported by bare name.
for DIRNAME in 'configs data models postprocess metrics utils repos'.split():
    sys.path.append(f'{BASEDIR}/{DIRNAME}/')

# Command-line options; parse_known_args tolerates arguments it does not recognise.
parser = argparse.ArgumentParser(description="")
parser.add_argument("-C", "--config", default="cfg_0", help="config filename")
parser.add_argument("-G", "--gpu_id", default="", help="GPU ID")
parser_args, other_args = parser.parse_known_args(sys.argv)

# Import the selected config module and work on a shallow copy of its `cfg` object.
config_module = importlib.import_module(parser_args.config)
cfg = copy(config_module.cfg)

try:
    from torchinfo import summary
except:
    print("[INFO] Couldn't find torchinfo... installing it.")
    !pip install -q torchinfo
    from torchinfo import summary


In [4]:
# Training table, read from the parquet path given in the config.
df = pd.read_parquet(cfg.train_df)

# Resolve the dataset helpers and the model class from the modules named in the config.
dataset_module = importlib.import_module(cfg.dataset)
LenMatchBatchSampler = dataset_module.LenMatchBatchSampler
DeviceDataLoader = dataset_module.DeviceDataLoader
Squeezeformer_RNA = importlib.import_module(cfg.model).Squeezeformer_RNA
BPPs_RNA_Dataset = dataset_module.BPPs_RNA_Dataset


In [5]:
# Cross-validation settings taken from the config: the fold to use and the number of folds.
fold, nfolds = cfg.fold, cfg.nfolds

In [6]:
# Training dataset for the selected fold, plus a second instance built with mask_only=True
# that is used only to drive the sampler.
ds_train = BPPs_RNA_Dataset(df, mode='train', fold=fold, nfolds=nfolds)
ds_train_len = BPPs_RNA_Dataset(df, mode='train', fold=fold, nfolds=nfolds, mask_only=True)

# Random order over the mask-only dataset, grouped into batches of cfg.bs by LenMatchBatchSampler
# (presumably by sequence length, judging by the name — TODO confirm against its implementation).
sampler_train = torch.utils.data.RandomSampler(ds_train_len)
len_sampler_train = LenMatchBatchSampler(sampler_train, batch_size=cfg.bs, drop_last=True)

# DataLoader raises ValueError for persistent_workers=True when num_workers == 0,
# so persistent workers are enabled only when worker processes are requested.
dl_train = DeviceDataLoader(
    torch.utils.data.DataLoader(
        ds_train,
        batch_sampler=len_sampler_train,
        num_workers=cfg.num_workers,
        persistent_workers=cfg.num_workers > 0,
    ),
    cfg.device,
)

In [7]:
# Pull a single batch from the training loader and keep its first element as `sample`.
first_batch = next(iter(dl_train))
sample = first_batch[0]

In [8]:
# Show which fields the sample batch exposes.
sample.keys()

dict_keys(['inputs', 'input_lengths', 'seq'])

In [9]:
# Resolve the model class from the module named in cfg.model and instantiate it on cfg.device.
# NOTE(review): infer_mode is passed as the string 'True', not the boolean True. Any non-empty
# string is truthy, so confirm that Squeezeformer_RNA really expects a string here.
Squeezeformer_RNA = importlib.import_module(cfg.model).Squeezeformer_RNA
model = Squeezeformer_RNA(cfg,infer_mode='True').to(cfg.device)

In [10]:
# Check the shape of the `input_lengths` entry of the sample batch.
sample['input_lengths'].shape

torch.Size([32])

In [11]:
# Fetch the `dict_to` helper from the dataset module named in the config.
dataset_helpers = importlib.import_module(cfg.dataset)
dict_to = dataset_helpers.dict_to

In [12]:
# Pass the batch through dict_to with the target device
# (presumably this moves its tensors onto cfg.device — TODO confirm against dict_to).
sample = dict_to(sample,cfg.device)

In [13]:
# Collect the three batch fields into a fresh dict and let torchinfo trace the model with it,
# printing input/output shapes, parameter counts and trainability down to depth 10.
summary_fields = ('inputs', 'input_lengths', 'seq')
tensors = {field: sample[field] for field in summary_fields}
summary(
    model,
    input_data=[tensors],
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=30,
    depth=10,
)

Layer (type:depth-idx)                                                 Input Shape                    Output Shape                   Param #                        Trainable
Squeezeformer_RNA                                                      [32, 206]                      [32, 206, 2]                   --                             True
├─Embedding: 1-1                                                       [32, 206]                      [32, 206, 384]                 1,536                          True
├─SqueezeformerEncoder: 1-2                                            [32, 206, 384]                 [32, 206, 384]                 3,031,306                      True
│    └─ModuleList: 2-1                                                 --                             --                             --                             True
│    │    └─SqueezeformerBlock: 3-1                                    [32, 206, 384]                 [32, 206, 384]                 --               

In [14]:
# Run the model once on the sample batch and keep the returned mapping in `y`.
y=model(sample);

In [21]:
# Look up the 'fc_outputs' entry of the model output; the trailing semicolon suppresses its display.
y['fc_outputs'];

In [17]:
# Optional environment setup, left disabled: system graphviz via apt plus the torchviz package.
#!apt-get update
#!apt-get install graphviz
#!pip install torchviz

Hit:1 http://archive.ubuntu.com/ubuntu jammy InRelease
Hit:2 http://security.ubuntu.com/ubuntu jammy-security InRelease
Hit:3 http://archive.ubuntu.com/ubuntu jammy-updates InRelease
Hit:4 http://archive.ubuntu.com/ubuntu jammy-backports InRelease
Reading package lists... Done
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
graphviz is already the newest version (2.42.2-6).
0 upgraded, 0 newly installed, 0 to remove and 97 not upgraded.


In [20]:
# Disabled autograd-graph rendering of the 'fc_outputs' tensor. The second variant also passes the
# model's named parameters to make_dot and renders the graph as a PNG named "rnn_torchviz".
#make_dot(y['fc_outputs'])
#make_dot(y['fc_outputs'], params=dict(list(model.named_parameters()))).render("rnn_torchviz", format="png")