In [1]:
import peft
from peft import LoraConfig, get_peft_model
import os
import torch
import numpy as np
from tqdm import tqdm
from torch import nn
from torch import optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
from torchvision.transforms import Resize
from torchvision.models import resnet152

[2024-03-23 13:50:02,109] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
# Point the http_proxy / https_proxy environment variables of this kernel at the
# local proxy listening on 127.0.0.1:7890.
os.environ.update(
    {
        'http_proxy': 'http://127.0.0.1:7890',
        'https_proxy': 'http://127.0.0.1:7890',
    }
)

In [3]:
# Use the first CUDA device when one is visible; otherwise fall back to the CPU so
# the notebook still runs (slowly) on a machine without a GPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# Preprocessing shared by the train and test splits: resize to 224, convert to a
# tensor, then normalise each channel with the given mean / std (these are the
# values commonly used with ImageNet-pretrained torchvision models).
normalize = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# The dataset is expected to already exist under ./data; pass download=True on a
# first run if it does not.
train_set = torchvision.datasets.CIFAR10(
    root="./data",
    train=True,
    transform=normalize,
)

# Shuffle the training split so every epoch visits the images in a different order.
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)

In [5]:
# Held-out CIFAR-10 split (train=False), preprocessed with the same `normalize` pipeline
# as the training split.
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, transform=normalize)

# Evaluation loader: fixed order, 128 images per batch, two worker processes.
test_loader = DataLoader(
    dataset=test_set,
    batch_size=128,
    shuffle=False,
    num_workers=2,
)

In [4]:
# Classification loss shared by train() and test(); "mean" is the default reduction
# and averages the per-sample losses over the batch.
criterion = nn.CrossEntropyLoss(reduction="mean")

In [6]:
def train(net, train_loader, lr=1e-3, epochs=20):
    """Optimise the trainable parameters of ``net`` with Adam and a cosine LR schedule.

    Only parameters with ``requires_grad=True`` are handed to the optimiser, and
    their total element count is printed before training starts. After every
    second epoch the notebook-level ``test`` function is called on ``net``.
    The model is left in eval mode on return.

    Relies on the module-level ``DEVICE``, ``criterion`` and ``test``.

    Args:
        net: Model to train; batches are moved to ``DEVICE`` before the forward pass.
        train_loader: Iterable yielding ``(inputs, targets)`` batches.
        lr: Initial learning rate for Adam.
        epochs: Number of passes over ``train_loader``; also the ``T_max`` of the
            cosine annealing schedule.

    Returns:
        None.
    """
    params_to_update = [w for w in net.parameters() if w.requires_grad]
    print("num of trainable parameters: ", sum(w.numel() for w in params_to_update))

    adam = optim.Adam(params_to_update, lr=lr)
    lr_schedule = optim.lr_scheduler.CosineAnnealingLR(adam, T_max=epochs)
    net.train()

    for epoch_no in range(1, epochs + 1):
        # test() switches the model to eval mode, so re-enter train mode each epoch.
        net.train()
        for images, labels in tqdm(train_loader):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            adam.zero_grad()
            batch_loss = criterion(net(images), labels)
            batch_loss.backward()
            adam.step()
        # The schedule advances once per epoch, not once per batch.
        lr_schedule.step()
        if epoch_no % 2 == 0:
            test(net)
    net.eval()

In [7]:
def test(net):
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    print( 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (test_loss/len(test_loader), 100.*correct/total, correct, total))

## Full fine-tuning (all parameters trainable)

In [8]:
# Full fine-tuning baseline: ResNet-152 with the default pretrained weights and
# every parameter left trainable.
model = resnet152(weights='DEFAULT')
# Replace the classifier head with a 10-output linear layer for CIFAR-10.
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, 10)
model.to(DEVICE)
# Shuffle the training split so every epoch visits the images in a different order.
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
train(model, train_loader, lr=3e-4)

num of trainable parameters:  58164298


 18%|█▊        | 71/391 [00:30<02:16,  2.34it/s]


KeyboardInterrupt: 

## LoRA fine-tuning

In [None]:
# Collect the name of every module in `model` whose exact type is Conv2d or Linear;
# these are the layers that receive LoRA adapters.
available_types = [torch.nn.Conv2d, torch.nn.Linear]
target_modules = [
    layer_name
    for layer_name, layer in model.named_modules()
    if type(layer) in available_types
]
# The classifier head is listed in modules_to_save below instead of being a LoRA target.
target_modules.remove('fc')

config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",  # 'none', 'all' or 'lora_only'
    target_modules=target_modules,
    modules_to_save=["fc"],
)

peft_model = get_peft_model(model, config)
peft_model = peft_model.to(DEVICE)
peft_model.print_trainable_parameters()

In [None]:
# LoRA run with 1024-image batches. The training split is shuffled so every epoch
# visits the images in a different order.
train_loader = DataLoader(train_set, batch_size=1024, shuffle=True, num_workers=2)
# NOTE(review): weights=None builds a randomly initialised backbone, whereas the
# full fine-tuning cell starts from weights='DEFAULT'. LoRA is normally applied on
# top of pretrained weights -- confirm that training adapters on a random backbone
# is intended here.
model = resnet152(weights=None, num_classes=10)
model.to(DEVICE)
# Reuses the LoraConfig built earlier; its target module names are assumed to match
# this freshly built ResNet-152.
peft_model = get_peft_model(model, config).to(DEVICE)
peft_model.print_trainable_parameters()
train(peft_model, train_loader)

In [None]:
# LoRA run with 2048-image batches. The training split is shuffled so every epoch
# visits the images in a different order.
train_loader = DataLoader(train_set, batch_size=2048, shuffle=True, num_workers=2)
# NOTE(review): weights=None builds a randomly initialised backbone, whereas the
# full fine-tuning cell starts from weights='DEFAULT'. LoRA is normally applied on
# top of pretrained weights -- confirm that training adapters on a random backbone
# is intended here.
model = resnet152(weights=None, num_classes=10)
model.to(DEVICE)
# Reuses the LoraConfig built earlier; its target module names are assumed to match
# this freshly built ResNet-152.
peft_model = get_peft_model(model, config).to(DEVICE)
peft_model.print_trainable_parameters()
train(peft_model, train_loader)

In [None]:
# Final evaluation of whichever `peft_model` was assigned most recently; test()
# iterates the notebook-level `test_loader` and prints loss and accuracy.
test(peft_model)