In [None]:
import torch
import os

import wandb
import torch.nn as nn
import time
import numpy as np
import pandas as pd


## Restrict this process to GPUs 0 and 1.
# NOTE: CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is first
# initialised in this kernel; restart the kernel if the device list looks wrong.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
print("Available GPUs:", torch.cuda.device_count())

# Report the currently selected GPU. The device queries raise when CUDA is not
# available, so they are skipped on CPU-only machines (the notebook falls back
# to the CPU later when it builds `device`).
if torch.cuda.is_available():
    print("Current device index:", torch.cuda.current_device())
    print("Device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
else:
    print("CUDA is not available; running on CPU.")

In [4]:

def train(train_dataloader, val_dataloader, model, optimizer, scheduler, config, device):
    """Train ``model`` for ``config.max_epoch`` epochs, validating after every epoch.

    Each batch is forwarded through ``model`` and the loss is computed with the
    underlying network's ``get_loss``. Gradients are accumulated over
    ``config.step_per_update`` batches before each optimizer step. Whenever the
    mean validation loss improves, the network's ``state_dict`` is saved to
    ``os.path.join(config.model_dir, config.model_name)``. Per-batch and
    per-epoch metrics are sent through ``wandb.log``.

    Args:
        train_dataloader: Sized iterable yielding ``(_, data)`` training batches.
        val_dataloader: Sized iterable yielding ``(_, data)`` validation batches.
        model: Network whose forward pass returns a 5-tuple whose last two items
            are the rebuilt points and the ground-truth points. It may be a
            wrapper exposing the real network as ``.module`` or the bare network.
        optimizer: Optimizer stepped after every ``config.step_per_update`` batches.
        scheduler: A scheduler, or a list of schedulers; each is stepped with the
            epoch number once per epoch.
        config: Object exposing ``model_dir``, ``model_name``, ``max_epoch`` and
            ``step_per_update``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        Tuple ``(train_loss_array, val_loss_array)`` of NumPy arrays holding the
        mean training / validation loss of each epoch.
    """
    # get_loss lives on the underlying network; unwrap DataParallel-style
    # wrappers but also accept a bare model.
    net = model.module if hasattr(model, "module") else model

    # Make sure the checkpoint directory exists before the first save.
    if config.model_dir:
        os.makedirs(config.model_dir, exist_ok=True)
    best_model_path = os.path.join(config.model_dir, config.model_name)

    schedulers = scheduler if isinstance(scheduler, list) else [scheduler]

    best_loss = 1e10
    train_losses = []
    val_losses = []

    model.zero_grad()
    for epoch in range(1, config.max_epoch + 1):
        model.train()
        running_train_loss = 0.0
        epoch_start_time = time.time()

        # Number of batches whose gradients have been accumulated but not applied.
        pending = 0

        for _, data in train_dataloader:
            batch_start_time = time.time()
            points = data.to(device)
            _, _, _, rebuilt_points, gt_points = model(points)
            loss = net.get_loss(rebuilt_points, gt_points)
            loss.backward()
            pending += 1

            if pending == config.step_per_update:
                optimizer.step()
                model.zero_grad()
                pending = 0

            batch_loss = loss.item()
            running_train_loss += batch_loss
            batch_duration = time.time() - batch_start_time

            wandb.log({
                "batch_duration": batch_duration,
                "batch_train_loss": batch_loss,
            })

        # Apply gradients left over when the number of batches is not a multiple
        # of step_per_update; otherwise they would silently carry over into the
        # first update of the next epoch.
        if pending:
            optimizer.step()
            model.zero_grad()

        for item in schedulers:
            item.step(epoch)

        epoch_train_loss = running_train_loss / len(train_dataloader)

        # Evaluate the model on the validation set.
        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for _, data in val_dataloader:
                points = data.to(device)
                _, _, _, rebuilt_points, gt_points = model(points)
                running_val_loss += net.get_loss(rebuilt_points, gt_points).item()

        epoch_val_loss = running_val_loss / len(val_dataloader)

        train_losses.append(epoch_train_loss)
        val_losses.append(epoch_val_loss)
        epoch_duration = time.time() - epoch_start_time
        wandb.log({
            "epoch": epoch,
            "epoch_train_loss": epoch_train_loss,
            "epoch_val_loss": epoch_val_loss,
            "epoch_duration": epoch_duration
        })

        print(f"Epoch {epoch}/{config.max_epoch} finished. Train Average Loss: {epoch_train_loss:.4f}, Validation Average Loss :{epoch_val_loss:.4f}, Duration: {epoch_duration:.2f}s")
        if epoch_val_loss < best_loss:
            best_loss = epoch_val_loss
            torch.save(net.state_dict(), best_model_path)
            print('Save the best model to %s' % best_model_path)

    return np.array(train_losses), np.array(val_losses)
 

In [5]:
def test(test_dataloader,model,config, device):
    model.eval()
    running_test_loss = 0.0
    with torch.no_grad():
        for idx,(_ ,data) in enumerate(test_dataloader):
            points = data.to(device)
            loss = model(points)
            running_test_loss += loss.item()

    test_loss = running_test_loss / len(test_dataloader)
    print(f"Test Average Loss :{test_loss:.4f}")
    return test_loss

In [None]:
## Model training: imports and configuration loading
import yaml
from easydict import EasyDict as edict
from models.Point_MAE import Point_MAE


## Load the configuration file.
# The file is decoded as UTF-8 explicitly so that parsing does not depend on the
# platform's default locale encoding (which matters for any non-ASCII text in it).
with open('pretrain.yaml', 'r', encoding='utf-8') as f:
    config = edict(yaml.safe_load(f))



In [None]:
from utils import misc
from datasets.pretrain.dataset import *
from models.build import *
## Set the random seed
misc.set_random_seed(42)

## Build the train / val / test loaders, in that order, from their config sections
train_dataloader, val_dataloader, test_dataloader = (
    dataset_builder(getattr(config.dataset, split))
    for split in ('train', 'val', 'test')
)


In [None]:

## Set up the model
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = Point_MAE(config.model)

In [10]:
# Replicate the model over the visible GPUs, capped at the two this notebook
# targets (devices 0 and 1). Deriving the ids from the actual device count
# avoids a failure on a single-GPU machine.
gpu_ids = list(range(min(2, torch.cuda.device_count())))
# Guard against wrapping twice when this cell is re-run: the training code
# reaches the network through `model.module`, which must be the real model.
if not isinstance(model, torch.nn.DataParallel):
    model = torch.nn.DataParallel(model, device_ids=gpu_ids)
model = model.to(device)
optimizer, scheduler = build_opti_sche(model, config)

In [None]:
## Training
start_time = time.time()
# Use the run as a context manager so it is always finished, even when train()
# raises; otherwise a failed cell would leave the wandb run open.
with wandb.init(
    project="pointmae-pretrain",
    config=config,
    name="pretrain"
):
    train_loss_array, val_loss_array = train(train_dataloader, val_dataloader, model, optimizer, scheduler, config, device)
end_time = time.time()
run_time = end_time - start_time
print(f"Total training time: {run_time:.2f} seconds")
