In [1]:
%%capture
!pip install ipywidgets

In [2]:
# ! wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar
# ! wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar

In [2]:
import torch

from torchvision import datasets

In [3]:
dataset_train = datasets.ImageNet('/notebooks/imagenet/')

In [4]:
dataset_train

Dataset ImageNet
    Number of datapoints: 1281167
    Root location: /notebooks/imagenet/
    Split: train

In [5]:
%%capture
! pip install git+https://github.com/keepsimpler/sunyata
! pip install pytorch-lightning
! pip install pytorch-lightning-bolts
! pip install einops

In [9]:
! pip install timm

Collecting timm
  Downloading timm-0.6.11-py3-none-any.whl (548 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m548.7/548.7 kB[0m [31m33.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: timm
Successfully installed timm-0.6.11
[0m

In [6]:
import torch
import torchvision
torch.__version__, torchvision.__version__

('1.12.0+cu116', '0.13.0+cu116')

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import LearningRateMonitor

from sunyata.pytorch.data.tiny_imagenet import TinyImageNet, TinyImageNetDataModule

from sunyata.pytorch.arch.base import BaseModule, Residual


In [36]:
from sunyata.pytorch.arch.convnext2 import ConvNext, ConvNextCfg, convnext_tiny

In [None]:
cfg = ConvNextCfg(
    drop_path = 0.1,
    model_ema = True,
    model_ema_eval = True,

    num_workers = 8,
)
cfg

In [46]:
from timm.data.constants import \
    IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.data import create_transform
from timm.data.mixup import Mixup
from timm.models import create_model
from timm.models.registry import register_model

In [19]:
mean = IMAGENET_INCEPTION_MEAN if not cfg.imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
std = IMAGENET_INCEPTION_STD if not cfg.imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD

In [20]:
transform = create_transform(
    input_size = cfg.input_size,
    is_training = True,
    color_jitter = cfg.color_jitter,
    auto_augment = cfg.aa,
    interpolation = cfg.train_interpolation,
    re_prob = cfg.reprob,
    re_mode = cfg.remode,
    re_count = cfg.recount,
    mean = mean,
    std = std,
)

In [21]:
resize_im = cfg.input_size > 32
if not resize_im:
    transform.transforms[0] = transforms.RandomCrop(
        cfg.input_size, padding=4
    )

In [25]:
dataset_train = datasets.ImageNet('/notebooks/imagenet/', transform=transform)

In [26]:
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

In [27]:
dataset_val = datasets.ImageNet('/notebooks/imagenet/', split='val', transform=val_transform)

In [29]:
data_loader_train = torch.utils.data.DataLoader(
    dataset_train,
    batch_size = cfg.batch_size,
    num_workers = cfg.num_workers,
    pin_memory = cfg.pin_mem,
    drop_last = True,
)

In [30]:
data_loader_val = torch.utils.data.DataLoader(
    dataset_val,
    batch_size = int(1.5 * cfg.batch_size),
    num_workers = cfg.num_workers,
    pin_memory = cfg.pin_mem,
    drop_last = False,
)

In [33]:
mixup_active = cfg.mixup > 0 or cfg.cutmix > 0. or cfg.cutmix_minmax is not None
mixup_active

True

In [34]:
mixup_fn = Mixup(
    mixup_alpha = cfg.mixup, cutmix_alpha = cfg.cutmix,
    cutmix_minmax = cfg.cutmix_minmax,
    prob = cfg.mixup_prob,
    switch_prob = cfg.mixup_switch_prob,
    mode = cfg.mixup_mode,
    label_smoothing = cfg.smoothing,
    num_classes = cfg.nb_classes
)

In [39]:
cfg.pretrained = False

In [47]:
@register_model
def convnext_tiny(pretrained=False, pretrained_cfg=None, **kwargs):
    model = ConvNext(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
    if pretrained:
        raise NotImplementedError
    return model


In [48]:
model = create_model(
    'convnext_tiny',
    pretrained=False, 
    pretrained_cfg=None,
    num_classes = cfg.nb_classes,
    drop_path_rate = cfg.drop_path,
    layer_scale_init_value = cfg.layer_scale_init_value,
    head_init_scale = cfg.head_init_scale,
)

In [8]:
model = ResNext50(cfg)

pl.utilities.model_summary.summarize(model,max_depth=1)





  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 23.4 M
---------------------------------
23.4 M    Trainable params
0         Non-trainable params
23.4 M    Total params
93.559    Total estimated model params size (MB)

In [9]:
trainer = pl.Trainer(
    # progress_bar_refresh_rate=10,
    precision=16,
    max_epochs=cfg.num_epochs,
    accelerator='gpu',
    devices=1,
    enable_checkpointing=False,
    logger=pl_loggers.CSVLogger("lightning_logs/", name="convmixer"),
#     callbacks=[LearningRateMonitor(logging_interval="step")],
)    


Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:

trainer.fit(model, tiny_image_net_datamodule)


can not write to csv file.
can not write to csv file.


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  stdout_func(
  stdout_func(

  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 23.4 M
---------------------------------
23.4 M    Trainable params
0         Non-trainable params
23.4 M    Total params
46.779    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]