In [1]:
import torch
import torchvision
from torchvision import transforms

In [2]:
class Args:
    img_size=224
    train_batch_size=16
    test_batch_size=32
    
args=Args()

In [3]:
from torchvision import transforms
## 数据增强
transform_train = transforms.Compose([
        transforms.RandomResizedCrop((args.img_size, args.img_size), scale=(0.05, 1.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
transform_test = transforms.Compose([
    transforms.Resize((args.img_size, args.img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
transforms={
    'train':transform_train,
    'test':transform_test
}

In [4]:
## 加载数据
train_dataset=torchvision.datasets.CIFAR10(
    root='./',
    train=True,
    download=True,
    transform=transforms['train']
)
test_dataset=torchvision.datasets.CIFAR10(
    root='./',
    train=False,
    download=True,
    transform=transforms['test']
)

train_dataloader=torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=args.train_batch_size,
    pin_memory=True
)

test_dataloader=torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=args.test_batch_size,
    pin_memory=True
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 79862523.54it/s]


Extracting ./cifar-10-python.tar.gz to ./
Files already downloaded and verified


## ViT实现

In [7]:
from torch import nn
def swish(x):
    return x*torch.sigmoid(x)
ACT2FN={
    "gelu":nn.functional.gelu,
    "relu":nn.functional.relu,
    "swish":swish
}


In [96]:
## embedding
class Embedding(nn.Module):
    def __init__(self, img_size=32,patch_size=4, in_c=3, embed_dim=768,dropout_rate=0.1):
        super().__init__()
        img_size=(img_size,img_size)
        patch_size=(patch_size,patch_size)
        self.grid_size=(img_size[0]//patch_size[0],img_size[1]//patch_size[1])
        self.num_patch=self.grid_size[0]*self.grid_size[1]
        self.proj=nn.Conv2d(in_c,embed_dim,kernel_size=patch_size,stride=patch_size)
        self.position_embeddings=nn.Parameter(torch.zeros(1,1,embed_dim))
        self.cls_token=nn.Parameter(torch.zeros(1,1,embed_dim))
        self.dropout=nn.Dropout(dropout_rate)
    def forward(self,x):
        B=x.shape[0]
        cls_tokens=self.cls_token.expand(B,-1,-1)
        x=self.proj(x)
        x=x.flatten(2)
        x=x.transpose(-1,-2)
        x=torch.cat((cls_tokens,x),dim=1)
        embeddings=x+self.position_embeddings
        embeddings=self.dropout(embeddings)
        return embeddings

In [97]:
class Attention(nn.Module):
    def __init__(self,dim,num_heads=12,qkv_bias=None,attn_drop_rate=0.,proj_drop_rate=0.):
        super().__init__()
        self.num_heads=num_heads
        self.head_dim=dim//num_heads
        new_dim=self.head_dim*self.num_heads
        self.new_dim=new_dim
        #self.scale=qk_scale or head_dim**(-0.5)
        self.q=nn.Linear(dim,new_dim)
        self.k=nn.Linear(dim,new_dim)
        self.v=nn.Linear(dim,new_dim)
        
        self.attn_drop=nn.Dropout(attn_drop_rate)
        self.o_proj=nn.Linear(new_dim,dim)
        self.proj_drop=nn.Dropout(proj_drop_rate)
        
        self.softmax=nn.Softmax(dim=-1)
    def transpose_for_scores(self,x):
        new_x_shape=x.shape[:-1]+(self.num_heads,self.head_dim)
        x=x.view(*new_x_shape)
        return x.permute(0,2,1,3) ##( B,head_nums,length,head_dim)
    def forward(self,x):
        B,N,C=x.shape
        q=self.q(x)
        k=self.k(x)
        v=self.v(x)
        
        q=self.transpose_for_scores(q)
        k=self.transpose_for_scores(k)
        v=self.transpose_for_scores(v)
        
        attention_scores=torch.matmul(q,k.transpose(-2,-1))
        attention_scores=attention_scores*(self.head_dim**(-0.5))
        attention_prods=self.softmax(attention_scores)## (B,head_nums,length,length)
        attention_prods=self.attn_drop(attention_prods)
        
        new_values=torch.matmul(attention_prods,v)## (B,head_nums,length,head_dim)
        new_values=new_values.permute(0,2,1,3).contiguous()
        new_shape=new_values.shape[:-2]+(self.new_dim,)
        new_values=new_values.view(*new_shape)## (B,length,new_dim)
        
        output=self.o_proj(new_values)
        output=self.proj_drop(output)
        return output
        
        
class MLP(nn.Module):
    def __init__(self,hidden_dim,mlp_dim,dropout_rate=0.1):
        super().__init__()
        self.fc1=nn.Linear(hidden_dim,mlp_dim)
        self.fc2=nn.Linear(mlp_dim,hidden_dim)
        self.act=ACT2FN['gelu']
        self.dropout=nn.Dropout(dropout_rate)
        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self,x):
        x=self.fc1(x)
        x=self.act(x)
        x=self.dropout(x)
        x=self.fc2(x)
        x=self.dropout(x)
        return x

class Block(nn.Module):
    def __init__(self,hidden_dim):
        super().__init__()
        self.layer_norm_attn=nn.LayerNorm(hidden_dim,eps=1e-6)
        self.layer_norm_ffn=nn.LayerNorm(hidden_dim,eps=1e-6)
        self.ffn=MLP(hidden_dim,hidden_dim*4,0.2)
        self.attn=Attention(hidden_dim)
    def forward(self,x):
        h=x
        x=self.layer_norm_attn(x)
        x=self.attn(x)
        x=x+h
        
        h=x
        x=self.layer_norm_ffn(x)
        x=self.ffn(x)
        x=x+h
        return x
    
class Encoder(nn.Module):
    def __init__(self,hidden_dim,num_layer):
        super().__init__()
        self.embedding=Embedding()
        self.blocks=nn.ModuleList()
        self.layer_norm=nn.LayerNorm(hidden_dim,eps=1e-6)
        for i in range(num_layer):
            self.blocks.append(Block(hidden_dim))
    def forward(self,x):
        x=self.embedding(x)
        for layer in self.blocks:
            x=layer(x)
        x=self.layer_norm(x)
        return x

class ViT(nn.Module):
    def __init__(self,hidden_dim,num_layer,num_classes):
        super().__init__()
        self.encoder=Encoder(hidden_dim,num_layer)
        self.head=nn.Linear(hidden_dim,num_classes)
    def forward(self,x):
        x=self.encoder(x)
        logits=self.head(x[:,0])
        return logits
        
        

## 从零开始训练

In [122]:
from transformers import get_scheduler
hidden_dim=768
num_layer=12
num_classes=10
device=torch.device('cuda')
model=ViT(hidden_dim,num_layer,num_classes).to(device)
device_ids = [0, 1]
model= torch.nn.DataParallel(model, device_ids=device_ids)
optimizer=torch.optim.AdamW(model.parameters(),lr=8e-4)
#optimizer = nn.DataParallel(optimizer, device_ids=device_ids)
num_warmup_steps=100
num_training_steps=2000
lr_scheduler = get_scheduler(
    name="cosine", optimizer=optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
)

loss_fn = nn.CrossEntropyLoss()
print(f'参数量:{sum(m.numel() for m in model.parameters())/1e6}M')

In [88]:
## 验证
from tqdm import tqdm
@torch.no_grad()
def eval_model():
    model.eval()
    total=0
    acc=0
    for batch,labels in tqdm(test_dataloader):
        batch=batch.to(device)
        labels=labels.to(device)
        logits= model(batch).view(-1,num_classes)
        pred=torch.argmax(logits,dim=-1)
        labels=labels.view(-1)
        acc+=torch.sum(pred==labels)
        total+=pred.shape[0]
    print('acc:',acc/total)
eval_model()

100%|██████████| 157/157 [00:20<00:00,  7.51it/s]

acc: tensor(0.0142, device='cuda:0')





In [None]:
## 训练
from tqdm import tqdm
num_epochs=5
for epoch in range(num_epochs):
    model.train()
    for batch,labels in tqdm(train_dataloader):
        batch=batch.to(device)
        labels=labels.to(device)
        logits= model(batch)
        loss=loss_fn(logits.view(-1,num_classes),labels.view(-1))
        # 反向传播和参数更新
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
    eval_model()
    


## 加载预训练模型

In [28]:
model=torchvision.models.vit_b_16(pretrained=True)
num_classes=10
model.heads=torch.nn.Linear(768,num_classes)
device=torch.device('cuda')
model=model.to(device)

In [29]:
device=torch.device('cuda')
optimizer=torch.optim.AdamW(model.parameters(),lr=2e-5)
loss_fn = torch.nn.CrossEntropyLoss()
print(f'参数量:{sum(m.numel() for m in model.parameters())/1e6}M')

参数量:85.806346M


In [None]:
## 验证
from tqdm import tqdm
@torch.no_grad()
def eval_model():
    model.eval()
    total=0
    acc=0
    for batch,labels in tqdm(test_dataloader):
        batch=batch.to(device)
        labels=labels.to(device)
        logits= model(batch).view(-1,num_classes)
        pred=torch.argmax(logits,dim=-1)
        labels=labels.view(-1)
        acc+=torch.sum(pred==labels)
        total+=pred.shape[0]
    print('acc:',acc/total)
## eval_model()

In [30]:
## 训练
from tqdm import tqdm
num_epochs=1
for epoch in range(num_epochs):
    model.train()
    i=0
    for batch,labels in tqdm(train_dataloader):
        batch=batch.to(device)
        labels=labels.to(device)
        logits= model(batch)
        loss=loss_fn(logits.view(-1,num_classes),labels.view(-1))
        # 反向传播和参数更新
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        i+=1
        if i>100:
            break
    eval_model()
    


  3%|▎         | 100/3125 [01:00<30:41,  1.64it/s]
100%|██████████| 313/313 [01:56<00:00,  2.70it/s]

acc: tensor(0.9051, device='cuda:0')





In [21]:
all_models = torchvision.models.list_models()
all_models

['alexnet',
 'convnext_base',
 'convnext_large',
 'convnext_small',
 'convnext_tiny',
 'deeplabv3_mobilenet_v3_large',
 'deeplabv3_resnet101',
 'deeplabv3_resnet50',
 'densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'efficientnet_b0',
 'efficientnet_b1',
 'efficientnet_b2',
 'efficientnet_b3',
 'efficientnet_b4',
 'efficientnet_b5',
 'efficientnet_b6',
 'efficientnet_b7',
 'efficientnet_v2_l',
 'efficientnet_v2_m',
 'efficientnet_v2_s',
 'fasterrcnn_mobilenet_v3_large_320_fpn',
 'fasterrcnn_mobilenet_v3_large_fpn',
 'fasterrcnn_resnet50_fpn',
 'fasterrcnn_resnet50_fpn_v2',
 'fcn_resnet101',
 'fcn_resnet50',
 'fcos_resnet50_fpn',
 'googlenet',
 'inception_v3',
 'keypointrcnn_resnet50_fpn',
 'lraspp_mobilenet_v3_large',
 'maskrcnn_resnet50_fpn',
 'maskrcnn_resnet50_fpn_v2',
 'maxvit_t',
 'mc3_18',
 'mnasnet0_5',
 'mnasnet0_75',
 'mnasnet1_0',
 'mnasnet1_3',
 'mobilenet_v2',
 'mobilenet_v3_large',
 'mobilenet_v3_small',
 'mvit_v1_b',
 'mvit_v2_s',
 'quantized_googlenet',
 '