-
Notifications
You must be signed in to change notification settings - Fork 0
/
alexnet_1d_ofrecord.py
118 lines (92 loc) · 3.12 KB
/
alexnet_1d_ofrecord.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import oneflow as flow
import oneflow.nn as nn
import flowvision
import flowvision.transforms as transforms
from ofrecord_data_utils import OFRecordDataLoader
import numpy as np
import time
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
BATCH_SIZE = 128
EPOCH_NUM = 3

# Two-CUDA-device placement plus split(0) / broadcast SBP signatures.
# NOTE(review): nothing else in this script references PLACEMENT, S0 or B;
# presumably kept for a global-tensor (1D parallel) variant — confirm.
PLACEMENT = flow.placement("cuda", [0, 1])
S0 = flow.sbp.split(0)
B = flow.sbp.broadcast

# Use CUDA when OneFlow can see a GPU, otherwise run on the CPU.
DEVICE = "cuda" if flow.cuda.is_available() else "cpu"
print(f"Using {DEVICE} device")

# ---------------------------------------------------------------------------
# OFRecord data loaders (train and val splits share one root directory)
# ---------------------------------------------------------------------------
_OFRECORD_ROOT = "/dataset/imagenette/ofrecord"

train_dataloader = OFRecordDataLoader(
    ofrecord_root=_OFRECORD_ROOT,
    mode="train",
    dataset_size=4000,
    batch_size=BATCH_SIZE,
)
val_data_loader = OFRecordDataLoader(
    ofrecord_root=_OFRECORD_ROOT,
    mode="val",
    dataset_size=400,
    batch_size=BATCH_SIZE,
)

# ---------------------------------------------------------------------------
# Model, loss and optimizer (AlexNet trained from scratch: pretrained=False)
# ---------------------------------------------------------------------------
model = flowvision.models.alexnet(
    pretrained=False,
    progress=True,
).to(DEVICE)
print(model)

loss_fn = nn.CrossEntropyLoss().to(DEVICE)
optimizer = flow.optim.SGD(model.parameters(), lr=1e-3)
class AlexNetGraph(flow.nn.Graph):
    """Compiled training step for the module-level AlexNet ``model``.

    Each call runs forward, computes the ``loss_fn`` cross-entropy, back-propagates,
    and applies the attached ``optimizer``.
    """

    def __init__(self):
        super().__init__()
        # Registered on the graph, but build() never calls it: the training
        # loop fetches batches eagerly and passes them in.
        self.train_data_loader = train_dataloader
        self.alexnet = model
        self.cross_entropy = loss_fn
        # Attaching the optimizer makes each graph call perform the update.
        self.add_optimizer(optimizer)

    def build(self, image, label):
        """Run one training iteration on ``(image, label)``; return the loss tensor."""
        batch_loss = self.cross_entropy(self.alexnet(image), label)
        batch_loss.backward()
        return batch_loss


alexnet_graph = AlexNetGraph()
class AlexNetEvalGraph(flow.nn.Graph):
    """Compiled inference step returning class probabilities for a batch.

    Wraps the same module-level ``model`` that ``AlexNetGraph`` trains, so a
    call evaluates the current weights.
    """

    def __init__(self):
        super().__init__()
        # Registered for parity with the training graph; build() does not
        # call it — the caller fetches batches and passes them in.
        self.val_data_loader = val_data_loader
        self.alexnet = model

    def build(self, image, label=None):
        """Compute softmax predictions for ``image`` and pass ``label`` through.

        Args:
            image: batch of input images, expected on ``DEVICE`` (as the
                training loop does for its batches).
            label: labels of the same batch. This is an explicit input so the
                returned labels belong to this batch, rather than being a
                module-level tensor looked up from outside the graph.

        Returns:
            ``(predictions, label)``, where ``predictions`` is ``logits.softmax()``.

        Raises:
            ValueError: if ``label`` is not supplied.
        """
        if label is None:
            raise ValueError(
                "AlexNetEvalGraph requires the batch labels: "
                "call alexnet_eval_graph(image, label)"
            )
        with flow.no_grad():
            logits = self.alexnet(image)
            predictions = logits.softmax()
        return predictions, label


alexnet_eval_graph = AlexNetEvalGraph()

of_losses = []  # loss values sampled every `print_interval` training iterations
all_samples = len(val_data_loader) * BATCH_SIZE  # denominator for val accuracy
print_interval = 10

start_t = time.time()
for epoch in range(EPOCH_NUM):
    model.train()
    for b in range(len(train_dataloader)):
        # oneflow graph train: read the batch eagerly, then run the compiled step.
        image, label = train_dataloader()
        image = image.to(DEVICE)
        label = label.to(DEVICE)
        loss = alexnet_graph(image, label)
        # Copy the loss back to host memory only on iterations that log it.
        if b % print_interval == 0:
            loss_value = loss.numpy()
            of_losses.append(loss_value)
            print(
                "epoch {} train iter {} oneflow loss {}".format(
                    epoch, b, loss_value
                )
            )
    end_t = time.time()
    # Elapsed wall-clock time since before the first epoch (cumulative).
    print("train time : {}".format(end_t - start_t))

    # Validation is disabled; uncomment to report top-1 accuracy each epoch.
    # It deliberately leaves start_t / end_t alone so the training time above
    # stays a pure training measurement.
    # NOTE(review): all_samples assumes every val batch is full (BATCH_SIZE).
    # print("epoch %d train done, start validation" % epoch)
    # model.eval()
    # correct_of = 0.0
    # for b in range(len(val_data_loader)):
    #     image, label = val_data_loader()
    #     image = image.to(DEVICE)
    #     label = label.to(DEVICE)
    #     predictions, label = alexnet_eval_graph(image, label)
    #     clsidxs = np.argmax(predictions.numpy(), axis=1)
    #     correct_of += int((clsidxs == label.numpy()).sum())
    # print("epoch %d, oneflow top1 val acc: %f" % (epoch, correct_of / all_samples))