# 知乎千赞的pytorch模型训练代码模版支持Multi-GPU-DDP模式啦

一般pytorch需要用户自定义训练循环，可以说有1000个pytorch用户就有1000种训练代码风格。

从实用角度讲，一个优秀的训练循环应当具备以下特点。

代码简洁易懂 【模块化、易修改、short-enough】

支持常用功能 【进度条、评估指标、early-stopping】

经过反复斟酌测试，我精心设计了仿照keras风格的pytorch训练循环，完全满足以上条件。

该方案在知乎受到许多读者喜爱，目前为止获得了超过600个赞。

知乎完整回答链接：《深度学习里面，请问有写train函数的模板吗？》

https://www.zhihu.com/question/523869554/answer/2633479163




以上pytorch模型训练模版也是我开源的一个pytorch模型训练工具 torchkeras库的核心代码。

https://github.com/lyhue1991/torchkeras

铛铛铛铛，torchkeras加入新功能啦。

最近，通过引入HuggingFace的accelerate库的功能，torchkeras进一步支持了 多GPU的DDP模式和TPU设备上的模型训练。

这里给大家演示一下，非常强大和丝滑。

accelerate库的一个简要介绍，可以参考我在知乎的文章。

《20分钟吃掉accelerate模型加速工具😋》

https://zhuanlan.zhihu.com/p/599274899

In [1]:
#从git安装最新的accelerate仓库
!pip install git+https://github.com/huggingface/accelerate

Collecting git+https://github.com/huggingface/accelerate
  Cloning https://github.com/huggingface/accelerate to /tmp/pip-req-build-fwewjjgf
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/accelerate /tmp/pip-req-build-fwewjjgf
  Resolved https://github.com/huggingface/accelerate to commit b22f088ff662de748cf3f97c7ad8bf5a6dd6a7b9
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: accelerate
  Building wheel for accelerate (pyproject.toml) ... done
  Created wheel for accelerate: filename=accelerate-0.15.0.dev0-py3-none-any.whl size=195428 sha256=41a490004fc65e286cb18d6896c6b2fc93129c85d8a100b3c4a3f0543ded6064
  Stored in directory: /tmp/pip-ephem-wheel-cache-u8mnojrw/wheels/81/c1/23/6068c1115888b4dd7da88f966c002c30840985c047f6cc1653
Successfully built accelerate
Installing collected pac

## 〇，torchkeras源码解析

torchkeras的核心代码在 下面这个文件中。

https://github.com/lyhue1991/torchkeras/blob/master/torchkeras/kerasmodel.py



In [None]:
import sys,datetime
from tqdm import tqdm 
from copy import deepcopy
import numpy as np
import pandas as pd
import torch
from accelerate import Accelerator

def colorful(obj,color="red", display_type="plain"):
    """Wrap ``str(obj)`` in ANSI escape codes for colored terminal output.

    Args:
        obj: Any object; it is rendered with ``str``.
        color: "black", "red", "green", "yellow", "blue", "purple", "cyan"
            or "white". An unknown name produces an empty color code.
        display_type: "plain", "highlight", "underline", "shine", "inverse"
            or "invisible". An unknown name produces an empty style code.

    Returns:
        The text preceded by ``ESC[<style>;<color>m`` and followed by the
        reset sequence ``ESC[0m``.
    """
    text = str(obj)

    # Foreground codes run consecutively from 30 (black) to 37 (white).
    palette = ("black", "red", "green", "yellow", "blue", "purple", "cyan", "white")
    fg_codes = {name: str(30 + offset) for offset, name in enumerate(palette)}
    style_codes = dict(plain="0", highlight="1", underline="4",
                       shine="5", inverse="7", invisible="8")

    fg = fg_codes.get(color, "")
    style = style_codes.get(display_type, "")
    return "\033[" + style + ";" + fg + "m" + text + "\033[0m"

class StepRunner:
    """Process a single batch: forward pass, loss, optional update, metrics.

    Args:
        net: The network; called as ``net(features)``.
        loss_fn: Callable ``loss_fn(preds, labels)`` returning the batch loss.
        accelerator: Object providing ``backward`` and ``gather`` (an
            ``accelerate.Accelerator`` in this file).
        stage: Prefix for the reported metric names; parameters are only
            updated when it equals "train".
        metrics_dict: Mapping of name -> callable ``metric(preds, labels)``.
            ``None`` means "no metrics".
        optimizer: Optimizer stepped after each training batch, if given.
        lr_scheduler: Scheduler stepped after each optimizer step, if given.
    """
    def __init__(self, net, loss_fn, accelerator, stage = "train", metrics_dict = None, 
                 optimizer = None, lr_scheduler = None
                 ):
        self.net,self.loss_fn,self.stage = net,loss_fn,stage
        # __call__ iterates over the metrics unconditionally, so normalise the
        # "no metrics" default to an empty mapping instead of keeping None.
        self.metrics_dict = metrics_dict if metrics_dict is not None else {}
        self.optimizer,self.lr_scheduler = optimizer,lr_scheduler
        self.accelerator = accelerator
    
    def __call__(self, batch):
        """Run one batch and return ``(loss, step_metrics)``.

        Args:
            batch: A ``(features, labels)`` pair.

        Returns:
            A tuple of the loss as a Python number and a dict mapping
            ``"<stage>_<metric name>"`` to the metric value for this batch.
        """
        features,labels = batch 
        
        #loss
        preds = self.net(features)
        loss = self.loss_fn(preds,labels)

        #backward()
        if self.optimizer is not None and self.stage=="train":
            self.accelerator.backward(loss)
            self.optimizer.step()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()

        # Collect predictions, labels and loss through the accelerator so the
        # metrics see every process's data.
        all_preds = self.accelerator.gather(preds)
        all_labels = self.accelerator.gather(labels)
        # NOTE: the gathered losses are summed (not averaged), so the reported
        # loss grows with the number of gathered entries.
        all_loss = self.accelerator.gather(loss).sum()
            
        #metrics
        step_metrics = {self.stage+"_"+name:metric_fn(all_preds, all_labels).item() 
                        for name,metric_fn in self.metrics_dict.items()}
        
        return all_loss.item(),step_metrics

class EpochRunner:
    """Drive a step runner over one full pass of a dataloader.

    Constructing the runner switches the network into train or eval mode
    according to the step runner's stage. Calling it iterates the dataloader
    with a progress bar and returns the epoch-level log.
    """
    def __init__(self,steprunner):
        self.steprunner = steprunner
        self.stage = steprunner.stage
        # Match the network mode to the stage.
        if self.stage=="train":
            self.steprunner.net.train()
        else:
            self.steprunner.net.eval()
        self.accelerator = self.steprunner.accelerator
        
    def __call__(self,dataloader):
        """Run one epoch and return its log.

        Args:
            dataloader: Sized iterable of batches accepted by the step runner.

        Returns:
            Dict with ``"<stage>_loss"`` (mean of the per-step losses) and one
            ``"<stage>_<name>"`` entry per metric, taken from the metric's
            ``compute()``. Metrics are reset afterwards.

        Raises:
            ValueError: If the dataloader has no batches, since no epoch
                statistics can be produced.
        """
        num_batches = len(dataloader)
        if num_batches == 0:
            raise ValueError(
                "dataloader for stage '{}' has no batches; check batch_size/drop_last".format(
                    self.stage))

        total_loss,step = 0,0
        loop = tqdm(enumerate(dataloader), 
                    total =num_batches,
                    file=sys.stdout,
                    disable=not self.accelerator.is_local_main_process,
                    ncols = 100
                   )
        
        for i, batch in loop: 
            if self.stage=="train":
                loss, step_metrics = self.steprunner(batch)
            else:
                # No gradients are needed outside of training.
                with torch.no_grad():
                    loss, step_metrics = self.steprunner(batch)
                    
            step_log = dict({self.stage+"_loss":loss},**step_metrics)
            total_loss += loss
            step+=1
            
            if i!=num_batches-1:
                loop.set_postfix(**step_log)
            else:
                # Last batch: show epoch-level numbers on the progress bar
                # instead of the step-level ones.
                epoch_loss = total_loss/step
                epoch_metrics = {self.stage+"_"+name:metric_fn.compute().item() 
                                 for name,metric_fn in self.steprunner.metrics_dict.items()}
                epoch_log = dict({self.stage+"_loss":epoch_loss},**epoch_metrics)
                loop.set_postfix(**epoch_log)
                for name,metric_fn in self.steprunner.metrics_dict.items():
                    metric_fn.reset()
        return epoch_log
    
class KerasModel(torch.nn.Module):
    """Keras-style ``fit`` / ``evaluate`` wrapper around a PyTorch network.

    Bundles a network with its loss, metrics, optimizer and optional
    learning-rate scheduler. Training and evaluation are built from
    ``StepRunner`` and ``EpochRunner``; device placement and multi-process
    handling are delegated to ``accelerate.Accelerator``.
    """
    def __init__(self,net,loss_fn,metrics_dict=None,optimizer=None,lr_scheduler = None):
        """Store the training components.

        Args:
            net: The network to train.
            loss_fn: Callable ``loss_fn(preds, labels)``.
            metrics_dict: Mapping name -> metric module, or ``None`` for no
                metrics. It is wrapped in a ``torch.nn.ModuleDict``.
            optimizer: Optimizer to use; defaults to Adam with ``lr=1e-3``
                over ``net.parameters()``.
            lr_scheduler: Optional learning-rate scheduler.
        """
        super().__init__()
        self.net,self.loss_fn = net, loss_fn
        self.metrics_dict = torch.nn.ModuleDict(metrics_dict) 
        self.optimizer = optimizer if optimizer is not None else torch.optim.Adam(
            self.net.parameters(), lr=1e-3)
        self.lr_scheduler = lr_scheduler

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        # Call the module itself (not ``.forward``) so registered hooks run.
        return self.net(x)

    def _monitor_names(self, has_val):
        """Return the set of history keys that ``fit`` is able to record."""
        metric_names = list(self.metrics_dict.keys())
        names = {"train_loss"} | {"train_"+name for name in metric_names}
        if has_val:
            names |= {"val_loss", "epoch"} | {"val_"+name for name in metric_names}
        return names

    def fit(self, train_data, val_data=None, epochs=10,ckpt_path='checkpoint.pt',
            patience=5, monitor="val_loss", mode="min", mixed_precision='no'):
        """Train the network with checkpointing and early stopping.

        Args:
            train_data: Training dataloader.
            val_data: Optional validation dataloader.
            epochs: Maximum number of epochs.
            ckpt_path: Where the best weights (by ``monitor``) are saved.
            patience: Stop once this many epochs in a row bring no new best.
            monitor: History key to track, e.g. "val_loss" or "val_acc".
            mode: "max" if larger is better; any other value means smaller
                is better.
            mixed_precision: Forwarded to ``Accelerator(mixed_precision=...)``.

        Returns:
            On the local main process, a ``pandas.DataFrame`` of the history
            (after the best checkpoint has been loaded back into ``self.net``);
            ``None`` on every other process.

        Raises:
            ValueError: If ``monitor`` can never appear in the history, for
                example a ``val_*`` key when ``val_data`` is ``None``.
        """
        # Fail before any training happens instead of hitting a KeyError on
        # ``history[monitor]`` after the first epoch.
        known_monitors = self._monitor_names(has_val = val_data is not None)
        if monitor not in known_monitors:
            raise ValueError(
                "monitor={!r} will never be recorded; choose one of {}".format(
                    monitor, sorted(known_monitors)))

        accelerator = Accelerator(mixed_precision=mixed_precision)
        device = str(accelerator.device)
        device_type = '🐌'  if 'cpu' in device else '⚡️'
        accelerator.print(colorful("<<<<<< "+device_type +" "+ device +" is used >>>>>>"))
    
        net,optimizer,lr_scheduler= accelerator.prepare(
            self.net,self.optimizer,self.lr_scheduler)
        train_dataloader,val_dataloader = accelerator.prepare(train_data,val_data)
        
        loss_fn = self.loss_fn
        if isinstance(loss_fn,torch.nn.Module):
            loss_fn.to(accelerator.device)
        metrics_dict = self.metrics_dict 
        metrics_dict.to(accelerator.device)
        
        history = {}
        
        for epoch in range(1, epochs+1):

            nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            accelerator.print("\n"+"=========="*8 + "%s"%nowtime)
            accelerator.print("Epoch {0} / {1}".format(epoch, epochs)+"\n")

            # 1, train -------------------------------------------------
            # Each stage gets its own copy of the metrics so that their
            # accumulated state is not shared between train and val.
            train_step_runner = StepRunner(
                    net = net,
                    loss_fn = loss_fn,
                    accelerator = accelerator,
                    stage="train",
                    metrics_dict=deepcopy(metrics_dict),
                    optimizer = optimizer,
                    lr_scheduler = lr_scheduler
            )

            train_epoch_runner = EpochRunner(train_step_runner)
            train_metrics = train_epoch_runner(train_dataloader)
            for name, metric in train_metrics.items():
                history.setdefault(name, []).append(metric)

            # 2, validate -------------------------------------------------
            if val_dataloader:
                val_step_runner = StepRunner(
                    net = net,
                    loss_fn = loss_fn,
                    accelerator = accelerator,
                    stage="val",
                    metrics_dict= deepcopy(metrics_dict)
                )
                val_epoch_runner = EpochRunner(val_step_runner)
                with torch.no_grad():
                    val_metrics = val_epoch_runner(val_dataloader)

                val_metrics["epoch"] = epoch
                for name, metric in val_metrics.items():
                    history.setdefault(name, []).append(metric)

            # 3, early-stopping -------------------------------------------------
            accelerator.wait_for_everyone()
            arr_scores = history[monitor]
            best_score_idx = np.argmax(arr_scores) if mode=="max" else np.argmin(arr_scores)

            # The latest epoch is the best so far: checkpoint the bare network.
            if best_score_idx==len(arr_scores)-1:
                unwrapped_net = accelerator.unwrap_model(net)
                accelerator.save(unwrapped_net.state_dict(),ckpt_path)
                accelerator.print(colorful("<<<<<< reach best {0} : {1} >>>>>>".format(monitor,
                     arr_scores[best_score_idx])))

            if len(arr_scores)-best_score_idx>patience:
                accelerator.print(colorful("<<<<<< {} without improvement in {} epoch, early stopping >>>>>>".format(
                    monitor,patience)))
                break 
                
        if accelerator.is_local_main_process:
            # Restore the best weights before handing the history back.
            self.net.load_state_dict(torch.load(ckpt_path))
            dfhistory = pd.DataFrame(history)
            accelerator.print(dfhistory)
            return dfhistory 
        return None
    
    @torch.no_grad()
    def evaluate(self, val_data):
        """Run one validation pass and return its metrics.

        Args:
            val_data: Validation dataloader.

        Returns:
            Dict with "val_loss" and one "val_<name>" entry per metric.
        """
        accelerator = Accelerator()
        # Keep the prepared model in a local: ``self.net`` must stay the
        # original module so that later ``self.net.load_state_dict(...)``
        # calls keep working with checkpoints of the bare network.
        net = accelerator.prepare(self.net)
        val_data = accelerator.prepare(val_data)
        if isinstance(self.loss_fn,torch.nn.Module):
            self.loss_fn.to(accelerator.device)
        self.metrics_dict.to(accelerator.device)
        
        val_step_runner = StepRunner(net = net,stage="val",
                    loss_fn = self.loss_fn,metrics_dict=deepcopy(self.metrics_dict),
                    accelerator = accelerator)
        val_epoch_runner = EpochRunner(val_step_runner)
        val_metrics = val_epoch_runner(val_data)
        return val_metrics
    

以上训练循环满足我所设想的全部特性。

模块化：自下而上分成 StepRunner, EpochRunner, 和KerasModel 三级，结构清晰明了。

易修改：如果输入和label形式有差异(例如，输入可能组装成字典，或者有多个输入)，仅需更改StepRunner就可以了，后面无需改动，非常灵活。

short-enough: 全部训练代码不到200行。

支持进度条：通过tqdm引入。

支持评估指标：可以引入torchmetrics库中的指标，也可以自定义评估指标。

支持early-stopping：在fit时候指定 monitor、mode、patience即可。



## 一，使用 CPU/单GPU 训练你的pytorch模型

当系统存在GPU时，torchkeras 会自动使用GPU训练你的pytorch模型，否则会使用CPU训练模型。

在我们的范例中，单GPU训练的话，一个Epoch大约是18s。

In [2]:
!pip install -U torchkeras 
!pip install -U torchmetrics 

Collecting torchkeras
  Downloading torchkeras-3.3.2-py3-none-any.whl (16 kB)
Installing collected packages: torchkeras
Successfully installed torchkeras-3.3.2
[0m

In [3]:
import torch
from torch import nn 
import torchvision 
from torchvision import transforms
import torchmetrics 
from torchkeras import KerasModel 

### 1, prepare the data

def create_dataloaders(batch_size=1024):
    """Build the MNIST training and validation dataloaders.

    The dataset is downloaded to ``./minist/`` if needed and images are
    converted with ``ToTensor``. Both loaders use two workers and drop the
    last incomplete batch; only the training loader shuffles.

    Args:
        batch_size: Number of samples per batch for both loaders.

    Returns:
        A ``(dl_train, dl_val)`` tuple.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])
    mnist_kwargs = dict(root="./minist/", download=True, transform=to_tensor)
    loader_kwargs = dict(batch_size=batch_size, num_workers=2, drop_last=True)

    train_set = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
    val_set = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

    train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, **loader_kwargs)
    val_loader = torch.utils.data.DataLoader(val_set, shuffle=False, **loader_kwargs)
    return train_loader, val_loader

dl_train,dl_val = create_dataloaders(batch_size=1024)

### 2, define the model

def create_net():
    """Build the CNN classifier used in this example.

    Two conv + max-pool stages, 2D dropout and a global adaptive max-pool,
    followed by a two-layer fully connected head with 10 outputs.

    Returns:
        An ``nn.Sequential`` whose children are registered under the names
        listed below, in this order.
    """
    named_layers = (
        ("conv1", nn.Conv2d(in_channels=1, out_channels=512, kernel_size=3)),
        ("pool1", nn.MaxPool2d(kernel_size=2, stride=2)),
        ("conv2", nn.Conv2d(in_channels=512, out_channels=256, kernel_size=5)),
        ("pool2", nn.MaxPool2d(kernel_size=2, stride=2)),
        ("dropout", nn.Dropout2d(p=0.1)),
        ("adaptive_pool", nn.AdaptiveMaxPool2d((1, 1))),
        ("flatten", nn.Flatten()),
        ("linear1", nn.Linear(256, 128)),
        ("relu", nn.ReLU()),
        ("linear2", nn.Linear(128, 10)),
    )
    model = nn.Sequential()
    for layer_name, layer in named_layers:
        model.add_module(layer_name, layer)
    return model

net = create_net() 


### 3, train the model

# Cross-entropy loss, with multiclass accuracy as the tracked metric.
loss_fn = nn.CrossEntropyLoss()
metrics_dict = dict(acc=torchmetrics.Accuracy(task='multiclass', num_classes=10))

optimizer = torch.optim.AdamW(params=net.parameters(), lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=optimizer, T_0=5)

model = KerasModel(net, loss_fn, metrics_dict, optimizer, lr_scheduler)

# Early stopping watches validation accuracy (higher is better).
fit_options = dict(
    epochs=5,
    ckpt_path='checkpoint.pt',
    patience=2,
    monitor='val_acc',
    mode='max',
    mixed_precision='no',
)
dfhistory = model.fit(train_data=dl_train, val_data=dl_val, **fit_options)

### 4, evaluate the model
# Reload the saved checkpoint, then report the validation metrics.
best_state = torch.load('checkpoint.pt')
model.net.load_state_dict(best_state)
val_report = model.evaluate(dl_val)
print(val_report)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./minist/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./minist/MNIST/raw/train-images-idx3-ubyte.gz to ./minist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./minist/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./minist/MNIST/raw/train-labels-idx1-ubyte.gz to ./minist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./minist/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./minist/MNIST/raw/t10k-images-idx3-ubyte.gz to ./minist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./minist/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./minist/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./minist/MNIST/raw

[0;31m<<<<<< ⚡️ cuda is used >>>>>>[0m

Epoch 1 / 5

100%|█████████████████████████████| 58/58 [00:23<00:00,  2.46it/s, train_acc=0.636, train_loss=1.68]
100%|███████████████████████████████████| 9/9 [00:01<00:00,  7.15it/s, val_acc=0.872, val_loss=1.01]
[0;31m<<<<<< reach best val_acc : 0.8717448115348816 >>>>>>[0m

Epoch 2 / 5

100%|████████████████████████████| 58/58 [00:16<00:00,  3.62it/s, train_acc=0.868, train_loss=0.747]
100%|██████████████████████████████████| 9/9 [00:01<00:00,  6.81it/s, val_acc=0.938, val_loss=0.433]
[0;31m<<<<<< reach best val_acc : 0.9381510615348816 >>>>>>[0m

Epoch 3 / 5

100%|████████████████████████████| 58/58 [00:16<00:00,  3.57it/s, train_acc=0.918, train_loss=0.401]
100%|██████████████████████████████████| 9/9 [00:01<00:00,  7.04it/s, val_acc=0.951, val_loss=0.257]
[0;31m<<<<<< reach best val_acc : 0.9510633945465088 >>>>>>[0m

Epoch 4 / 5

100%|██████████████████

## 二，使用多GPU DDP模式训练你的pytorch模型

Kaggle中右边settings 中的 ACCELERATOR选择 GPU T4x2。

### 1，设置config 

In [None]:
import os
from accelerate.utils import write_basic_config
# Create a basic accelerate config file, then terminate this Python process
# so the next session starts fresh and picks the new config up.
write_basic_config() # Write a config file
os._exit(0) # Restart the notebook to reload info from the latest config file 

In [None]:
# %load /root/.cache/huggingface/accelerate/default_config.yaml
{
  "compute_environment": "LOCAL_MACHINE",
  "deepspeed_config": {},
  "distributed_type": "MULTI_GPU",
  "downcast_bf16": false,
  "fsdp_config": {},
  "machine_rank": 0,
  "main_process_ip": null,
  "main_process_port": null,
  "main_training_function": "main",
  "mixed_precision": "no",
  "num_machines": 1,
  "num_processes": 2,
  "use_cpu": false
}


In [None]:
# or answer some question to create a config
#!accelerate config  

### 2，训练代码

在我们的范例中，双GPU使用DDP模式训练的话，一个Epoch大约是12s。

In [1]:
import torchvision 
from torchvision import transforms
from torch import nn 
import torch
import torchmetrics 
from accelerate import notebook_launcher
from torchkeras import KerasModel 

### 1, prepare the data

def create_dataloaders(batch_size=1024):
    """Build the MNIST training and validation dataloaders.

    The dataset is downloaded to ``./minist/`` if needed and images are
    converted with ``ToTensor``. Both loaders use two workers and drop the
    last incomplete batch; only the training loader shuffles.

    Args:
        batch_size: Number of samples per batch for both loaders.

    Returns:
        A ``(dl_train, dl_val)`` tuple.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])
    mnist_kwargs = dict(root="./minist/", download=True, transform=to_tensor)
    loader_kwargs = dict(batch_size=batch_size, num_workers=2, drop_last=True)

    train_set = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
    val_set = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

    train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, **loader_kwargs)
    val_loader = torch.utils.data.DataLoader(val_set, shuffle=False, **loader_kwargs)
    return train_loader, val_loader

dl_train,dl_val = create_dataloaders(batch_size=1024)

### 2, define the model

def create_net():
    """Build the CNN classifier used in this example.

    Two conv + max-pool stages, 2D dropout and a global adaptive max-pool,
    followed by a two-layer fully connected head with 10 outputs.

    Returns:
        An ``nn.Sequential`` whose children are registered under the names
        listed below, in this order.
    """
    named_layers = (
        ("conv1", nn.Conv2d(in_channels=1, out_channels=512, kernel_size=3)),
        ("pool1", nn.MaxPool2d(kernel_size=2, stride=2)),
        ("conv2", nn.Conv2d(in_channels=512, out_channels=256, kernel_size=5)),
        ("pool2", nn.MaxPool2d(kernel_size=2, stride=2)),
        ("dropout", nn.Dropout2d(p=0.1)),
        ("adaptive_pool", nn.AdaptiveMaxPool2d((1, 1))),
        ("flatten", nn.Flatten()),
        ("linear1", nn.Linear(256, 128)),
        ("relu", nn.ReLU()),
        ("linear2", nn.Linear(128, 10)),
    )
    model = nn.Sequential()
    for layer_name, layer in named_layers:
        model.add_module(layer_name, layer)
    return model

net = create_net() 


### 3, train the model

# Cross-entropy loss, with multiclass accuracy as the tracked metric.
loss_fn = nn.CrossEntropyLoss() 
metrics_dict = {'acc':torchmetrics.Accuracy(task='multiclass',num_classes=10)}

optimizer = torch.optim.AdamW(params=net.parameters(), lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer=optimizer,T_0=5)

model = KerasModel(net,loss_fn,metrics_dict,optimizer,lr_scheduler)

ckpt_path = 'checkpoint.pt'

# The launcher forwards positional arguments only, so the entries below must
# stay in the same order as the parameters of ``KerasModel.fit``. The keys
# are kept purely for readability.
fit_kwargs = dict(train_data = dl_train,
        val_data = dl_val,
        epochs=5,
        ckpt_path= ckpt_path,
        patience=2,
        monitor='val_acc',
        mode='max',
        mixed_precision='no')
# Pass a real tuple rather than a dict view, as the launcher expects.
args = tuple(fit_kwargs.values())

notebook_launcher(model.fit, args, num_processes=2)

### 4, evaluate the model
# Load the checkpoint written during fit (same path as above) before
# evaluating in this process.
model.net.load_state_dict(torch.load(ckpt_path))
print(model.evaluate(dl_val)) 


Launching training on 2 GPUs.
[0;31m<<<<<< ⚡️ cuda:0 is used >>>>>>[0m

Epoch 1 / 5

100%|█████████████████████████████| 29/29 [00:13<00:00,  2.14it/s, train_acc=0.575, train_loss=3.96]
100%|███████████████████████████████████| 4/4 [00:01<00:00,  2.73it/s, val_acc=0.859, val_loss=3.14]
[0;31m<<<<<< reach best val_acc : 0.8587646484375 >>>>>>[0m

Epoch 2 / 5

100%|█████████████████████████████| 29/29 [00:09<00:00,  2.92it/s, train_acc=0.815, train_loss=2.52]
100%|████████████████████████████████████| 4/4 [00:01<00:00,  2.37it/s, val_acc=0.899, val_loss=1.8]
[0;31m<<<<<< reach best val_acc : 0.8985595703125 >>>>>>[0m

Epoch 3 / 5

100%|██████████████████████████████| 29/29 [00:10<00:00,  2.88it/s, train_acc=0.873, train_loss=1.5]
100%|███████████████████████████████████| 4/4 [00:01<00:00,  2.95it/s, val_acc=0.922, val_loss=1.06]
[0;31m<<<<<< reach best val_acc : 0.922119140625 >>>>>>[0m

Epoch 4 / 5

100%|████████████████████████████| 29/29 [00:10<00:00,  2.74it/s, train_acc=0.90

## 三，使用TPU加速你的pytorch模型

Kaggle中右边settings 中的 ACCELERATOR选择 TPU v3-8。

### 1，安装torch_xla

In [None]:
#安装torch_xla支持
!pip uninstall -y torch torch_xla 
!pip install torch==1.8.2+cpu -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl

In [None]:
#从git安装最新的accelerate仓库
!pip install git+https://github.com/huggingface/accelerate

In [None]:
!pip install -U torchkeras 
!pip install -U torchmetrics 

In [None]:
#检查是否成功安装 torch_xla 
import torch_xla 

### 2，训练代码

torchmetrics库和TPU兼容性不太好，可以去掉metrics_dict进行训练。

In [1]:
import torch
from torch import nn 
import torchvision 
from torchvision import transforms
from accelerate import notebook_launcher

from torchkeras import KerasModel 

### 1, prepare the data

def create_dataloaders(batch_size=1024):
    """Build the MNIST training and validation dataloaders.

    The dataset is downloaded to ``./minist/`` if needed and images are
    converted with ``ToTensor``. Both loaders use two workers and drop the
    last incomplete batch; only the training loader shuffles.

    Args:
        batch_size: Number of samples per batch for both loaders.

    Returns:
        A ``(dl_train, dl_val)`` tuple.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])
    mnist_kwargs = dict(root="./minist/", download=True, transform=to_tensor)
    loader_kwargs = dict(batch_size=batch_size, num_workers=2, drop_last=True)

    train_set = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
    val_set = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

    train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, **loader_kwargs)
    val_loader = torch.utils.data.DataLoader(val_set, shuffle=False, **loader_kwargs)
    return train_loader, val_loader

dl_train,dl_val = create_dataloaders(batch_size=1024)

### 2, define the model

def create_net():
    """Build the CNN classifier used in this example.

    Two conv + max-pool stages, 2D dropout and a global adaptive max-pool,
    followed by a two-layer fully connected head with 10 outputs.

    Returns:
        An ``nn.Sequential`` whose children are registered under the names
        listed below, in this order.
    """
    named_layers = (
        ("conv1", nn.Conv2d(in_channels=1, out_channels=512, kernel_size=3)),
        ("pool1", nn.MaxPool2d(kernel_size=2, stride=2)),
        ("conv2", nn.Conv2d(in_channels=512, out_channels=256, kernel_size=5)),
        ("pool2", nn.MaxPool2d(kernel_size=2, stride=2)),
        ("dropout", nn.Dropout2d(p=0.1)),
        ("adaptive_pool", nn.AdaptiveMaxPool2d((1, 1))),
        ("flatten", nn.Flatten()),
        ("linear1", nn.Linear(256, 128)),
        ("relu", nn.ReLU()),
        ("linear2", nn.Linear(128, 10)),
    )
    model = nn.Sequential()
    for layer_name, layer in named_layers:
        model.add_module(layer_name, layer)
    return model

net = create_net() 

### 3, train the model

loss_fn = nn.CrossEntropyLoss() 

optimizer = torch.optim.AdamW(params=net.parameters(), lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer=optimizer,T_0=5)

# No metrics are passed here (third argument is None), so only the loss is
# tracked and used for early stopping.
model = KerasModel(net,loss_fn,None,optimizer,lr_scheduler)

from accelerate import notebook_launcher

ckpt_path = 'checkpoint.pt'

# The launcher forwards positional arguments only, so the entries below must
# stay in the same order as the parameters of ``KerasModel.fit``. The keys
# are kept purely for readability.
fit_kwargs = dict(train_data = dl_train,
        val_data = dl_val,
        epochs=5,
        ckpt_path= ckpt_path,
        patience=2,
        monitor='val_loss',
        mode='min',
        mixed_precision='no')
# Pass a real tuple rather than a dict view, as the launcher expects.
args = tuple(fit_kwargs.values())

notebook_launcher(model.fit, args, num_processes=8)


Launching a training on 8 TPU cores.


torchkeras.LightModel can't be used!


[0;31m<<<<<< ⚡️ xla:1 is used >>>>>>[0m

Epoch 1 / 5

100%|████████████████████████████████████████████████| 7/7 [00:12<00:00,  1.75s/it, train_loss=17.8]
100%|████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.11s/it, val_loss=17]
[0;31m<<<<<< reach best val_loss : 17.017776489257812 >>>>>>[0m

Epoch 2 / 5

100%|████████████████████████████████████████████████| 7/7 [00:12<00:00,  1.78s/it, train_loss=16.4]
100%|██████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.09s/it, val_loss=15.5]
[0;31m<<<<<< reach best val_loss : 15.495216369628906 >>>>>>[0m

Epoch 3 / 5

100%|████████████████████████████████████████████████| 7/7 [00:12<00:00,  1.76s/it, train_loss=14.8]
100%|██████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.11s/it, val_loss=13.6]
[0;31m<<<<<< reach best val_loss : 13.585103988647461 >>>>>>[0m

Epoch 4 / 5

100%|████████████████████████████████████████████████| 7/7 [00:12<00:00,  1.83s/it, train_loss=13.