# train.py -- forked from iMoonLab/HGNN
import os
import time
import copy
import torch
import torch.optim as optim
import pprint as pp
import utils.hypergraph_utils as hgut
from models import HGNN
from config import get_config
from datasets import load_feature_construct_H
# Restrict the CUDA devices visible to this process to GPU 0.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
cfg = get_config('config/config.yaml')

# initialize data: pick the feature file that matches the configured dataset
data_dir = cfg['modelnet40_ft'] if cfg['on_dataset'] == 'ModelNet40' \
    else cfg['ntu2012_ft']
fts, lbls, idx_train, idx_test, H = \
    load_feature_construct_H(data_dir,
                             m_prob=cfg['m_prob'],
                             K_neigs=cfg['K_neigs'],
                             is_probH=cfg['is_probH'],
                             use_mvcnn_feature=cfg['use_mvcnn_feature'],
                             use_gvcnn_feature=cfg['use_gvcnn_feature'],
                             use_mvcnn_feature_for_structure=cfg['use_mvcnn_feature_for_structure'],
                             use_gvcnn_feature_for_structure=cfg['use_gvcnn_feature_for_structure'])
G = hgut.generate_G_from_H(H)
# Labels are treated as 0-based class ids, so the class count is max + 1.
n_class = int(lbls.max()) + 1
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# transform data to device
# Floating-point inputs keep the float32 default of the legacy constructor.
fts = torch.Tensor(fts).to(device)
G = torch.Tensor(G).to(device)
# Integer data (labels and node indices) is built directly as int64.
# Going through torch.Tensor(...).long() would pass the values through
# float32 first, which cannot represent integers above 2**24 exactly.
lbls = torch.as_tensor(lbls, dtype=torch.long).squeeze().to(device)
idx_train = torch.as_tensor(idx_train, dtype=torch.long).to(device)
idx_test = torch.as_tensor(idx_test, dtype=torch.long).to(device)
def train_model(model, criterion, optimizer, scheduler, num_epochs=25, print_freq=500):
    """Train ``model`` full-batch and return it with its best weights loaded.

    Each epoch runs a 'train' phase on ``idx_train`` followed by a 'val'
    phase on ``idx_test``. Both phases run one forward pass over the whole
    graph using the module-level ``fts`` (features), ``G`` (graph operator)
    and ``lbls`` (labels); only the rows selected by the phase's index
    tensor contribute to the loss and the accuracy.

    Args:
        model: Module called as ``model(fts, G)`` that returns per-node
            class scores.
        criterion: Loss callable invoked as ``criterion(outputs[idx], lbls[idx])``.
            The value it returns is reported as the epoch loss unchanged.
        optimizer: Optimizer that updates ``model``'s parameters.
        scheduler: Learning-rate scheduler; stepped once per epoch, after
            the optimizer step.
        num_epochs: Number of epochs to run.
        print_freq: Progress is printed on epochs divisible by this value.

    Returns:
        ``model`` with the weights that reached the highest 'val' accuracy.

    Note:
        The 'val' phase uses ``idx_test``, so the best weights (and the
        reported best accuracy) are selected on that same split.
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        verbose = epoch % print_freq == 0
        if verbose:
            print('-' * 10)
            print(f'Epoch {epoch}/{num_epochs - 1}')

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            idx = idx_train if is_train else idx_test

            optimizer.zero_grad()
            with torch.set_grad_enabled(is_train):
                outputs = model(fts, G)
                loss = criterion(outputs[idx], lbls[idx])
                _, preds = torch.max(outputs, 1)

                # backward + optimize only if in training phase
                if is_train:
                    loss.backward()
                    optimizer.step()
                    # Step the schedule only after the optimizer has stepped;
                    # the other order skips the first value of the schedule.
                    scheduler.step()

            # statistics
            # The criterion is evaluated on the selected rows only, so its
            # value is reported as-is instead of being rescaled by the total
            # node count and then divided by the split size.
            epoch_loss = loss.item()
            n_correct = torch.sum(preds[idx] == lbls[idx]).item()
            epoch_acc = n_correct / len(idx)

            if verbose:
                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        if verbose:
            print(f'Best val Acc: {best_acc:.4f}')
            print('-' * 20)

    time_elapsed = time.time() - since
    print(f'\nTraining complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
def _main():
    """Print the run configuration, build the HGNN model and train it.

    Reads the module-level ``cfg``, ``fts``, ``n_class`` and ``device``,
    creates the Adam optimizer, the multi-step learning-rate schedule and
    the cross-entropy criterion from the configuration, then hands
    everything to ``train_model``.
    """
    # Report which dataset and feature sources this run uses.
    print(f"Classification on {cfg['on_dataset']} dataset!!! class number: {n_class}")
    print(f"use MVCNN feature: {cfg['use_mvcnn_feature']}")
    print(f"use GVCNN feature: {cfg['use_gvcnn_feature']}")
    print(f"use MVCNN feature for structure: {cfg['use_mvcnn_feature_for_structure']}")
    print(f"use GVCNN feature for structure: {cfg['use_gvcnn_feature_for_structure']}")

    # Dump the complete configuration between two marker lines.
    print('Configuration -> Start')
    pp.pprint(cfg)
    print('Configuration -> End')

    # The input width is taken from the second dimension of the feature tensor.
    net = HGNN(
        in_ch=fts.shape[1],
        n_class=n_class,
        n_hid=cfg['n_hid'],
        dropout=cfg['drop_out'],
    ).to(device)

    adam = optim.Adam(
        net.parameters(),
        lr=cfg['lr'],
        weight_decay=cfg['weight_decay'],
    )
    lr_schedule = optim.lr_scheduler.MultiStepLR(
        adam,
        milestones=cfg['milestones'],
        gamma=cfg['gamma'],
    )
    loss_fn = torch.nn.CrossEntropyLoss()

    net = train_model(
        net,
        loss_fn,
        adam,
        lr_schedule,
        num_epochs=cfg['max_epoch'],
        print_freq=cfg['print_freq'],
    )


if __name__ == '__main__':
    _main()