In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torchvision.models as models
from torchsummary import summary
from config.root_path import DATA_ROOT
from urllib.request import urlopen
from PIL import Image
import timm

In [5]:
# Root of the image dataset: the '8_class_select' folder under DATA_ROOT.
root_path = os.path.join(DATA_ROOT, '8_class_select')


def _build_transform(augment):
    """Return the preprocessing pipeline for one dataset split.

    Every image goes through CenterCrop(256), Resize(152), ToTensor and a
    per-channel Normalize. When ``augment`` is true, a RandomHorizontalFlip
    is inserted between the resize and the tensor conversion.
    """
    stages = [transforms.CenterCrop(256), transforms.Resize(152)]
    if augment:
        stages.append(transforms.RandomHorizontalFlip())
    stages.extend([
        transforms.ToTensor(),
        # Mean / std commonly used with ImageNet-pretrained torchvision models.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    return transforms.Compose(stages)


# Only the training split gets the random flip; validation is deterministic.
data_transform = {
    "train": _build_transform(True),
    "val": _build_transform(False),
}


In [6]:
batch_size = 64
lr = 0.0001


def resolve_num_workers(batch_size, max_workers=8):
    """Pick the number of DataLoader worker processes.

    The result is the smallest of: the CPU count reported by the OS, the
    batch size (or 0 when the batch size is 1 or less), and ``max_workers``.

    ``os.cpu_count()`` may return ``None`` when the count cannot be
    determined; in that case a single CPU is assumed instead of letting
    ``min`` fail on a ``None``/``int`` comparison.
    """
    cpu_total = os.cpu_count() or 1
    batch_cap = batch_size if batch_size > 1 else 0
    return min(cpu_total, batch_cap, max_workers)


nw = resolve_num_workers(batch_size)  # number of workers
print('Using {} dataloader workers every process\n'.format(nw))


# One ImageFolder dataset per split, each paired with the transform of the same name.
train_dir = os.path.join(root_path, 'train')
val_dir = os.path.join(root_path, 'val')
train_dataset = datasets.ImageFolder(train_dir, transform=data_transform['train'])
val_dataset = datasets.ImageFolder(val_dir, transform=data_transform['val'])

# Both loaders share batch size and worker count; only the training loader shuffles.
loader_options = dict(batch_size=batch_size, num_workers=nw)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

Using 8 dataloader workers every process



In [7]:
# Preferred accelerator is GPU index 3. Fall back to the first GPU when that
# index does not exist on this machine, and to the CPU when CUDA is not
# available at all, so that moving the model to `device` does not fail.
PREFERRED_GPU = 3
if torch.cuda.is_available() and torch.cuda.device_count() > PREFERRED_GPU:
    device = torch.device('cuda:{}'.format(PREFERRED_GPU))
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [11]:
# Wide ResNet-50-2 initialised from torchvision's IMAGENET1K_V2 checkpoint
# (the checkpoint is downloaded to the torch hub cache when not already present).
net = models.wide_resnet50_2(weights=models.Wide_ResNet50_2_Weights.IMAGENET1K_V2)

Downloading: "https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth" to /home/wangcheng/.cache/torch/hub/checkpoints/wide_resnet50_2-9ba9bcbe.pth
100%|██████████| 263M/263M [00:22<00:00, 12.4MB/s] 


In [7]:
# Show the model's module tree (repr of the cell's last expression).
# NOTE(review): the saved output below is a RegNet, while the visible load cell
# creates a wide_resnet50_2 -- the output is stale; re-run from a fresh kernel.
net

RegNet(
  (stem): SimpleStemIN(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (trunk_output): Sequential(
    (block1): AnyStage(
      (block1-0): ResBottleneckBlock(
        (proj): Conv2dNormActivation(
          (0): Conv2d(32, 224, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (f): BottleneckTransform(
          (a): Conv2dNormActivation(
            (0): Conv2d(32, 224, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (b): Conv2dNormActivation(
            (0): Conv2d(224, 224, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=2, bias=False)
      

In [21]:
# Replace the pretrained classification layer with an 8-way linear layer.
# torchvision ResNet-style models (such as wide_resnet50_2) expose the final
# layer as `net.fc` and have no `classifier` attribute, so indexing
# `net.classifier[2]` on them raises AttributeError. Models whose head is a
# Sequential with the linear layer at index 2 (e.g. ConvNeXt) keep the original
# code path. The input width is read from the layer being replaced rather than
# hard-coded, so it stays correct for any backbone width.
num_classes = 8
if hasattr(net, 'fc'):
    net.fc = nn.Linear(net.fc.in_features, num_classes)
else:
    net.classifier[2] = nn.Linear(net.classifier[2].in_features, num_classes)

In [22]:

# Show the model's module tree (repr of the cell's last expression).
# NOTE(review): the saved output below is a ConvNeXt, but no visible cell loads
# one -- the output comes from an earlier kernel state; re-run from a fresh kernel.
net

ConvNeXt(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((128,), eps=1e-06, elementwise_affine=True)
    )
    (1): Sequential(
      (0): CNBlock(
        (block): Sequential(
          (0): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128)
          (1): Permute()
          (2): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=128, out_features=512, bias=True)
          (4): GELU(approximate='none')
          (5): Linear(in_features=512, out_features=128, bias=True)
          (6): Permute()
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): CNBlock(
        (block): Sequential(
          (0): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128)
          (1): Permute()
          (2): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
          (3): Linear(

In [19]:
# Rebuild a (Dropout, Linear) classification head with 8 outputs.
# `dropout` and `in_features` are not defined by any visible cell, so they are
# taken from the model's current head instead of relying on leftover kernel
# state. NOTE(review): presumably they originally came from an EfficientNet-style
# head (Dropout followed by Linear) -- confirm against the intended backbone.
current_head = getattr(net, 'classifier', None)
if (isinstance(current_head, nn.Sequential)
        and len(current_head) == 2
        and isinstance(current_head[0], nn.Dropout)
        and isinstance(current_head[1], nn.Linear)):
    dropout = current_head[0]
    in_features = current_head[1].in_features
    net.classifier = nn.Sequential(
        dropout,
        nn.Linear(in_features, 8)
    )
else:
    # Any other head layout (e.g. a ResNet `fc` layer) is left untouched here.
    print('Skipping classifier rebuild: model has no (Dropout, Linear) classifier head')

In [20]:
# Show the model's module tree (repr of the cell's last expression).
# NOTE(review): the saved output below is an EfficientNet, but no visible cell
# loads one -- the output comes from an earlier kernel state; re-run from a fresh kernel.
net

EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): FusedMBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): FusedMBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  

In [None]:
# Move the model onto the selected device, then confirm the placement by
# printing where its first parameter lives (only the first one is inspected).
net.to(device)
first_param = next(net.parameters(), None)
if first_param is not None:
    print(first_param.device)

In [None]:
# Cross-entropy loss for the multi-class classification head.
loss_function = nn.CrossEntropyLoss()
# The misspelled name is kept as an alias because the training loop refers to it.
loss_funtion = loss_function

# Use the learning rate declared in the hyper-parameter cell (`lr`) instead of a
# second, different literal, so there is a single source of truth for the value.
optimizer = optim.Adam(net.parameters(), lr=lr)

In [None]:
# Number of full passes over the training set.
NUM_EPOCHS = 10

for epoch in range(NUM_EPOCHS):
    net.train()
    running_loss = 0.0  # sum of per-batch losses within this epoch
    num_batches = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = loss_funtion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    # Report one line per epoch instead of printing devices and the cumulative
    # loss on every batch; max(..., 1) guards against an empty loader.
    mean_loss = running_loss / max(num_batches, 1)
    print('epoch {}/{} - mean train loss: {:.4f}'.format(epoch + 1, NUM_EPOCHS, mean_loss))
        

