In [None]:
import torch
import os

import wandb
import time
import torch.nn as nn
import numpy as np
import pandas as pd
from utils.logger import *


## Restrict which GPUs are visible to this process.
# NOTE: this has to run before the first CUDA call in the kernel, because the
# variable is read when CUDA is initialised.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
print("Available GPUs:", torch.cuda.device_count())
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Report the currently selected GPU. The queries are guarded because they raise
# on a machine without a usable CUDA device, even though `device` above has
# already fallen back to the CPU.
if torch.cuda.is_available():
    print("Current device index:", torch.cuda.current_device())
    print("Device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
else:
    print("CUDA is not available; running on CPU.")

In [4]:

def train(train_dataloader, val_dataloader, model, optimizer, scheduler, config, device):
    """Fine-tune ``model`` and save the weights with the lowest validation loss.

    Every epoch runs one pass over ``train_dataloader`` with gradient
    accumulation (one optimizer step every ``config.step_per_update`` batches),
    steps the scheduler(s), and then evaluates on ``val_dataloader`` under
    ``torch.no_grad()``. Per-batch and per-epoch metrics are sent through
    ``wandb.log``, so a wandb run is expected to be active already.

    Args:
        train_dataloader: Iterable of 5-tuples supporting ``len()``; the 2nd
            item is the model input and the 4th item is the label (``age``).
        val_dataloader: Same layout as ``train_dataloader``.
        model: The network, either bare or wrapped in a container exposing the
            real network as ``.module`` (e.g. ``nn.DataParallel``). The
            underlying network must provide ``get_loss_acc(ret, label, loss)``
            returning a scalar loss tensor.
        optimizer: Object whose ``step()`` applies the accumulated gradients.
        scheduler: A scheduler, or a list of schedulers; each is called as
            ``step(epoch)`` once per epoch.
        config: Attribute-style config providing ``model_dir``, ``model_name``,
            ``max_epoch``, ``step_per_update``, ``model.loss`` and, optionally,
            ``grad_norm_clip`` (looked up with ``config.get``).
        device: Device the inputs and labels are moved to.

    Returns:
        tuple: ``(train_loss_array, val_loss_array)``, two 1-D numpy arrays
        holding the mean training / validation loss of every epoch.
    """
    start_epoch = 1

    best_loss = 1e10
    train_loss_array = np.array([])
    val_loss_array = np.array([])

    # Use the wrapped network when there is one, otherwise the model itself,
    # so the function also works without a DataParallel-style wrapper.
    core_model = getattr(model, "module", model)

    best_model_path = os.path.join(config.model_dir, config.model_name)
    # torch.save does not create missing directories.
    ckpt_dir = os.path.dirname(best_model_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    grad_norm_clip = config.get('grad_norm_clip')

    def _apply_update():
        """Clip (if configured), step the optimizer and clear the gradients."""
        if grad_norm_clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm_clip, norm_type=2)
        optimizer.step()
        model.zero_grad()

    model.zero_grad()
    for epoch in range(start_epoch, config.max_epoch + 1):
        model.train()
        running_train_loss = 0.0
        epoch_start_time = time.time()
        num_iter = 0  # batches accumulated since the last optimizer step

        for _, data, _, age, _ in train_dataloader:
            batch_start_time = time.time()
            num_iter += 1

            points = data.to(device)
            label = age.to(device).float()

            # Forward pass and loss; backward() adds to the accumulated grads.
            ret = model(points)
            loss = core_model.get_loss_acc(ret, label, config.model.loss)
            loss.backward()

            # Apply the update once enough batches have been accumulated.
            if num_iter == config.step_per_update:
                num_iter = 0
                _apply_update()

            batch_loss = loss.item()
            running_train_loss += batch_loss
            batch_duration = time.time() - batch_start_time

            wandb.log({
                "batch_duration": batch_duration,
                "batch_train_loss": batch_loss,
            })

        # Flush a trailing partial accumulation window. Without this the
        # gradients of the last few batches would never be applied on their
        # own and would instead be mixed into the next epoch's first update.
        if num_iter > 0:
            _apply_update()

        if isinstance(scheduler, list):
            for item in scheduler:
                item.step(epoch)
        else:
            scheduler.step(epoch)

        epoch_train_loss = running_train_loss / len(train_dataloader)

        ## Evaluate the model on the validation set
        val_start_time = time.time()
        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for _, data, _, age, _ in val_dataloader:
                points = data.to(device)
                label = age.to(device).float()
                ret = model(points)
                loss = core_model.get_loss_acc(ret, label, config.model.loss)
                running_val_loss += loss.item()

        epoch_val_loss = running_val_loss / len(val_dataloader)

        train_loss_array = np.append(train_loss_array, epoch_train_loss)
        val_loss_array = np.append(val_loss_array, epoch_val_loss)
        val_duration = time.time() - val_start_time
        epoch_duration = time.time() - epoch_start_time

        wandb.log({
            "epoch":epoch,
            "val_duration": val_duration,
            "epoch_train_loss": epoch_train_loss,
            "epoch_val_loss": epoch_val_loss,
            "epoch_duration": epoch_duration
        })
        print(f"Epoch {epoch}/{config.max_epoch} finished. Train Average Loss: {epoch_train_loss:.4f}, Validation Average Loss :{epoch_val_loss:.4f}, Duration: {epoch_duration:.2f}s")
        # Keep only the weights of the best epoch so far (unwrapped network,
        # so the saved keys carry no wrapper prefix).
        if epoch_val_loss < best_loss:
            best_loss = epoch_val_loss
            torch.save(core_model.state_dict(), best_model_path)
            print('Save the best model to %s' % best_model_path)

    return train_loss_array, val_loss_array



In [5]:
## 训练模型 
import yaml
import pickle
from easydict import EasyDict as edict
from models.Point_MAE import PointTransformer



In [None]:
## Load the fine-tuning configuration
# The encoding is pinned to UTF-8 so that parsing does not depend on the
# platform's default locale encoding.
with open('finetune_age.yaml', 'r', encoding='utf-8') as f:
    config = edict(yaml.safe_load(f))


In [None]:
from utils import misc
from datasets.finetune.dataset import dataset_builder
## Fix the random seed
misc.set_random_seed(42)

## Load the datasets (built in train, val, test order)
train_dataloader, val_dataloader, test_dataloader = (
    dataset_builder(getattr(config.dataset, split))
    for split in ('train', 'val', 'test')
)


In [None]:
## Build the model
from models.build import *
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = PointTransformer(config.model)
model.load_model_from_ckpt(config.ckpt_file)
# Use at most two GPUs, but never more than are actually visible: a hard-coded
# [0, 1] breaks when fewer than two GPUs are available. An empty list is turned
# into None so DataParallel falls back to its own device detection.
device_ids = list(range(min(2, torch.cuda.device_count())))
model = torch.nn.DataParallel(model, device_ids=device_ids or None)
model = model.to(device)
optimizer, scheduler = build_opti_sche(model, config)


In [None]:

## Train
start_time = time.time()
wandb.init(
    project="pointmae-finetune",
    config=config,
    name="age"
)
try:
    train_loss_array, val_loss_array = train(train_dataloader, val_dataloader, model, optimizer, scheduler, config, device)
except BaseException:
    # Also covers KeyboardInterrupt from stopping the cell: close the run and
    # mark it as failed instead of leaving it open, then re-raise.
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()
end_time = time.time()
run_time = end_time - start_time
print(f"Total training time: {run_time:.2f} seconds")