In [1]:
import os
# Run the rest of the notebook from the dataset directory. The location can be
# overridden with the RDROP_DATA_DIR environment variable; when it is unset the
# original hard-coded path is used, so existing setups keep working.
os.chdir(os.environ.get("RDROP_DATA_DIR", "/home/v-runmao/projects/R-Drop/vit_src/data/"))

In [2]:
import timm
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from tqdm.auto import tqdm
import torch.nn as nn
import torch.optim as optim
import torch.cuda.amp as amp
import torch

In [3]:
# CIFAR-100 training split, rooted at the current working directory.
# No transform is attached at construction time.
data = CIFAR100(
    root="./",
    train=True,
)

In [4]:
# Build an EfficientNet-B4 classifier on the first GPU and print its layers.
# The classifier head has 100 outputs to match the CIFAR-100 labels that this
# notebook trains on.
model = timm.create_model(
    "efficientnet_b4",
    drop_rate=0.4,
    drop_path_rate=0.2,
    num_classes=100,
).to("cuda:0")
print(model)

EfficientNet(
  (conv_stem): Conv2d(3, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): SiLU(inplace=True)
  (blocks): Sequential(
    (0): Sequential(
      (0): DepthwiseSeparableConv(
        (conv_dw): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48, bias=False)
        (bn1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): SiLU(inplace=True)
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(48, 12, kernel_size=(1, 1), stride=(1, 1))
          (act1): SiLU(inplace=True)
          (conv_expand): Conv2d(12, 48, kernel_size=(1, 1), stride=(1, 1))
          (gate): Sigmoid()
        )
        (conv_pw): Conv2d(48, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): Identity()
 

In [6]:
# Resolve the data configuration for `model`, passing an empty args dict, and
# show the resulting settings.
empty_args = {}
data_cfg = resolve_data_config(empty_args, model=model)
print(data_cfg)

{'input_size': (3, 320, 320), 'interpolation': 'bicubic', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'crop_pct': 1.0}


In [7]:
# Build the training-mode preprocessing pipeline from the resolved data config
# and show the composed transforms.
tfm = create_transform(**data_cfg, is_training=True)
print(tfm)

Compose(
    RandomResizedCropAndInterpolation(size=(320, 320), scale=(0.08, 1.0), ratio=(0.75, 1.3333), interpolation=PIL.Image.BICUBIC)
    RandomHorizontalFlip(p=0.5)
    ColorJitter(brightness=[0.6, 1.4], contrast=[0.6, 1.4], saturation=[0.6, 1.4], hue=None)
    ToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)


In [12]:
# Attach the preprocessing pipeline to the dataset, then wrap it in a loader
# that yields shuffled, pinned-memory batches of 256 and drops the final
# partial batch.
data.transform = tfm
loader = DataLoader(
    dataset=data,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
)

In [15]:
# Short mixed-precision training run: perform a handful of SGD steps and record
# the loss of each one in `losses`.
losses = []
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.05)
scaler = amp.GradScaler(enabled=True)

# Cells in this notebook may be run out of order, so make sure dropout,
# drop-path and batch-norm are in training mode no matter what ran before.
model.train()

for i, (x, y) in enumerate(tqdm(loader)):
    # The loader uses pinned memory, so the host-to-device copies can be async.
    x = x.to("cuda:0", non_blocking=True)
    y = y.to("cuda:0", non_blocking=True)

    # Forward pass and loss under autocast.
    with amp.autocast(True):
        logits = model(x)
        loss = criterion(logits, y)

    # Backward on the scaled loss, then step and update through the scaler.
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    losses.append(loss.item())
    # Stops after the step with i == 21, i.e. 22 steps in total.
    if i > 20:
        break

  0%|          | 0/195 [00:00<?, ?it/s]

In [16]:
# Show the losses collected by the most recently run loop.
losses

[4.566964626312256,
 4.540361404418945,
 4.48688268661499,
 4.528109550476074,
 4.523447036743164,
 4.523294448852539,
 4.569365978240967,
 4.482911109924316,
 4.486227989196777,
 4.529461860656738,
 4.4272379875183105,
 4.531652450561523,
 4.51021671295166,
 4.53688907623291,
 4.466104030609131,
 4.508081912994385,
 4.560098171234131,
 4.469296932220459,
 4.580643653869629,
 4.533179759979248,
 4.468344688415527,
 4.489779949188232]

In [13]:
# Forward-only loss measurement over the first batches of `loader`; the
# per-batch losses are collected in `losses`.
losses = []
criterion = nn.CrossEntropyLoss()

# no_grad alone does not disable dropout/drop-path or stop batch-norm from
# updating its running statistics, so switch to eval mode for the measurement
# and restore the previous mode afterwards for any later training cell.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        for i, (x, y) in enumerate(tqdm(loader)):
            x, y = x.to("cuda:0"), y.to("cuda:0")

            logits = model(x)
            loss = criterion(logits, y)
            losses.append(loss.item())
            # Stops after the batch with i == 51, i.e. 52 batches in total.
            if i > 50:
                break
finally:
    model.train(was_training)

  0%|          | 0/62 [00:00<?, ?it/s]

Bad pipe message: %s [b'\xa7\x07\xfc:\xde\xa0\xe7\xd1\xf3\xde[\x80S9\xe0\xaf\xca\xf0 \xb7\x8e&!\xa1\x01\x9e\x85\xbf)\xed95\r;\xd9#\xfd\xac\xa0\x99\xf0\xe3V\xb5\xd2\xfd\x1bi\x88\xd3\xd7\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x16\x00\x00\x00\x17\x00\x00\x00\r\x00\x1e\x00\x1c\x04\x03\x05\x03\x06\x03\x08\x07\x08\x08\x08\t\x08\n\x08\x0b\x08\x04\x08\x05\x08\x06\x04\x01\x05\x01\x06\x01\x00+\x00\x03\x02\x03\x04\x00-\x00\x02\x01\x01\x003\x00&\x00$\x00\x1d\x00 m\xa8\x0b\xc5']
Bad pipe message: %s [b'0\x94\x86V\xe2']
Bad pipe message: %s [b'\xe9\x16E\x1d`<-C6\xdc\x83&']
Bad pipe message: %s [b"\x00}\x02`\x00\x00\xa6\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k\x00j\xc0s\xc0w\x00\xc4