In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torchvision.models as models
from torchsummary import summary
from config.root_path import DATA_ROOT

In [5]:
root_path = os.path.join(DATA_ROOT, '8_class_select')

# Spatial size fed to the network. torchvision's maxvit_t is constructed for a
# fixed 224x224 input (its attention partitions are sized at build time), so the
# previous Resize(152) produced feature maps that cannot be split into 7x7
# partitions. 224 also works for convolutional backbones such as EfficientNet.
IMAGE_SIZE = 224
# Standard ImageNet per-channel mean / std used by torchvision models.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

data_transform = {
    # Training: centre crop + resize, with random horizontal flips as augmentation.
    "train": transforms.Compose([
        transforms.CenterCrop(256),
        transforms.Resize(IMAGE_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
    # Validation: identical geometry and normalization, no augmentation.
    "val": transforms.Compose([
        transforms.CenterCrop(256),
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
}


In [6]:
batch_size = 64
lr = 0.0001
# Number of DataLoader worker processes: bounded by the CPU count and by 8.
# os.cpu_count() returns None when the count cannot be determined, which would
# make min() raise TypeError, so fall back to 1 in that case.
nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])  # number of workers
print('Using {} dataloader workers every process\n'.format(nw))

# Page-locked host memory speeds up host->GPU copies; only meaningful with CUDA.
use_pin_memory = torch.cuda.is_available()

train_dataset = datasets.ImageFolder(os.path.join(root_path, 'train'),
                                     transform=data_transform['train'])
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=nw,
                          pin_memory=use_pin_memory)

val_dataset = datasets.ImageFolder(os.path.join(root_path, 'val'),
                                   transform=data_transform['val'])
val_loader = DataLoader(val_dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=nw,
                        pin_memory=use_pin_memory)

Using 8 dataloader workers every process



In [7]:
# Preferred GPU index on the shared machine. If that index does not exist, use
# the default CUDA device; if CUDA is unavailable, use the CPU. Otherwise the
# first .to(device) call fails on machines with fewer (or no) GPUs.
GPU_INDEX = 3
if torch.cuda.is_available():
    device = torch.device(
        f'cuda:{GPU_INDEX}' if GPU_INDEX < torch.cuda.device_count() else 'cuda:0'
    )
else:
    device = torch.device('cpu')
device

In [3]:
# Build MaxViT-T with randomly initialized weights (weights=None is the default).
# NOTE(review): the transforms normalize with ImageNet statistics, which usually
# implies fine-tuning from pretrained weights — confirm whether
# weights=models.MaxVit_T_Weights.IMAGENET1K_V1 was intended.
net = models.maxvit_t(weights=None)

In [4]:
net  # display the architecture of the freshly built model (MaxVit per the saved output)

MaxVit(
  (stem): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
      (2): GELU(approximate='none')
    )
    (1): Conv2dNormActivation(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (blocks): ModuleList(
    (0): MaxVitBlock(
      (layers): ModuleList(
        (0): MaxVitLayer(
          (layers): Sequential(
            (MBconv): MBConv(
              (proj): Sequential(
                (0): AvgPool2d(kernel_size=3, stride=2, padding=1)
                (1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
              )
              (stochastic_depth): Identity()
              (layers): Sequential(
                (pre_norm): BatchNorm2d(64, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
                (conv_a): Conv2dNormActivation(
           

In [6]:
# Replace the final nn.Linear of the classification head with one sized for our
# dataset. The input width and class count are read rather than hard-coded
# (previously index 5, width 512, 8 classes), so this cell also works if it is
# re-run or if `net` is a different model whose head ends in nn.Linear.
num_classes = len(train_dataset.classes)
net.classifier[-1] = nn.Linear(net.classifier[-1].in_features, num_classes)

In [7]:

net  # re-display the model after the classifier-head replacement in the previous cell

MaxVit(
  (stem): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
      (2): GELU(approximate='none')
    )
    (1): Conv2dNormActivation(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (blocks): ModuleList(
    (0): MaxVitBlock(
      (layers): ModuleList(
        (0): MaxVitLayer(
          (layers): Sequential(
            (MBconv): MBConv(
              (proj): Sequential(
                (0): AvgPool2d(kernel_size=3, stride=2, padding=1)
                (1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
              )
              (stochastic_depth): Identity()
              (layers): Sequential(
                (pre_norm): BatchNorm2d(64, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
                (conv_a): Conv2dNormActivation(
           

In [19]:
# Swap the trailing nn.Linear of the classification head for one with the
# dataset's number of classes. The previous version referenced `dropout` and
# `in_features`, which no visible cell defines (NameError on a fresh kernel).
# Replacing only the last layer keeps whatever precedes it (for torchvision's
# EfficientNet that is the Dropout layer) and reads the input width from the
# existing Linear.
num_classes = len(train_dataset.classes)
net.classifier[-1] = nn.Linear(net.classifier[-1].in_features, num_classes)

In [20]:
net  # NOTE(review): saved output shows an EfficientNet, which no visible cell constructs — `net` was reassigned in a deleted/out-of-order cell; re-run from a fresh kernel to confirm which model is trained

EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): FusedMBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): FusedMBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  

In [None]:
net = net.to(device)  # Module.to moves parameters in place and returns the same module
# Sanity check: report the device of the first parameter only.
first_param = next(net.parameters(), None)
if first_param is not None:
    print(first_param.device)

In [None]:
# Multi-class classification loss on raw logits and integer class labels.
# (Name kept as `loss_funtion` because the training cell refers to it.)
loss_funtion = nn.CrossEntropyLoss()
# Use the `lr` defined in the configuration cell next to batch_size; the
# previous hard-coded lr=0.001 silently ignored that setting.
optimizer = optim.Adam(net.parameters(), lr=lr)

In [None]:
num_epochs = 10
for epoch in range(num_epochs):
    # ---- training pass ----
    net.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = loss_funtion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Mean per-batch loss for the epoch (guard against an empty loader).
    train_loss = running_loss / max(len(train_loader), 1)

    # ---- validation pass: no gradients, eval-mode layers ----
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            preds = net(inputs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    val_acc = correct / total if total else float('nan')

    # One summary line per epoch instead of per-batch prints.
    print(f'epoch {epoch + 1}/{num_epochs}  train_loss={train_loss:.4f}  val_acc={val_acc:.4f}')
        

