In [13]:
from datasets import load_dataset
from torchvision.models import resnet18
import torch
import torch.optim as optim

from torchvision import transforms
from torch.utils.data import DataLoader
from copy import deepcopy
import torch.nn as nn
from conv_gemm.base_torch.gemm2col import Gem2ColConv2d 

from tqdm import tqdm
from PIL import Image


In [2]:
ds = load_dataset("zh-plus/tiny-imagenet")


In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
base = resnet18(num_classes=200).to(device)
model_base  = deepcopy(base) 
model_gemm = deepcopy(base)

  return torch._C._cuda_getDeviceCount() > 0


In [4]:
def replace_conv2d_with_gem2col(module: nn.Module):
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Conv2d) and child.groups == 1:
            new = Gem2ColConv2d(
                in_channels=child.in_channels,
                out_channels=child.out_channels,
                kernel_size=child.kernel_size,
                stride=child.stride,
                padding=child.padding,
                dilation=child.dilation,
                bias=(child.bias is not None),
            ).to(next(module.parameters(), torch.tensor(0, device=device)).device)
            with torch.no_grad():
                new.weight.copy_(child.weight)         # <-- ключевой момент
                if child.bias is not None and new.bias is not None:
                    new.bias.copy_(child.bias)
            setattr(module, name, new)
        else:
            replace_conv2d_with_gem2col(child)

In [5]:
replace_conv2d_with_gem2col(model_gemm)
model_base.eval()
model_gemm.eval()

ResNet(
  (conv1): Gem2ColConv2d()
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Gem2ColConv2d()
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Gem2ColConv2d()
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Gem2ColConv2d()
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Gem2ColConv2d()
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Gem2ColConv2d()
      (bn1): BatchNorm2d(128, eps=1e-05

In [8]:
def to_rgb(img):
    # HF Datasets обычно отдают PIL.Image, но подстрахуемся
    if isinstance(img, Image.Image):
        return img.convert("RGB")
    if torch.is_tensor(img):
        # tensor [C,H,W]
        if img.ndim == 3 and img.size(0) == 1:
            return img.expand(3, -1, -1)  # уже тензор -> вернём тензор
        return img
    # numpy -> PIL -> RGB
    return Image.fromarray(img).convert("RGB")

In [9]:
train_tfms = transforms.Compose([
    transforms.Lambda(to_rgb),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),  # теперь гарантированно [3,H,W]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

valid_tfms = transforms.Compose([
    transforms.Lambda(to_rgb),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),  # [3,H,W]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

In [10]:
def make_collate(tfms):
    def collate(batch):
        imgs = [tfms(sample["image"]) for sample in batch]
        labels = torch.tensor([int(sample["label"]) for sample in batch], dtype=torch.long)
        return torch.stack(imgs, 0), labels
    return collate

train_loader = DataLoader(
    ds["train"], batch_size=64, shuffle=True,
    num_workers=4, pin_memory=True,
    collate_fn=make_collate(train_tfms),
)
valid_loader = DataLoader(
    ds["valid"], batch_size=128, shuffle=False,
    num_workers=4, pin_memory=True,
    collate_fn=make_collate(valid_tfms),
)

# Трейн

In [14]:
model = resnet18(weights=None)
num_classes = 200
model.fc = nn.Linear(model.fc.in_features, num_classes)
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [15]:
for epoch in range(5):
    model.train()
    running_loss = 0.0
    for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        out = model(imgs)
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * imgs.size(0)
    print(f"Train loss: {running_loss / len(train_loader.dataset):.4f}")

    # === Валидация ===
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in valid_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            preds = model(imgs).argmax(1)
            correct += (preds == labels).sum().item()
            total += labels.numel()
    print(f"Val acc: {correct / total * 100:.2f}%")

Epoch 1:   0%|▍                                                                                                                                                                        | 4/1563 [00:12<1:22:21,  3.17s/it]


KeyboardInterrupt: 