In [1]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms


In [2]:
dataset = datasets.ImageFolder(
    'dataset',
    transforms.Compose([
        transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
)

In [3]:
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len(dataset) - 65, 65])

In [4]:
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=0,
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=0,
)

In [5]:
model = models.resnet18(pretrained=True)
model.fc = torch.nn.Linear(512, 2)

In [6]:
device = torch.device('cuda')
model = model.to(device)

In [7]:
import datetime
today = datetime.date.today()
strToday = str(today.year) + '_' + str(today.month) + '_' + str(today.day)

RESNET_MODEL = 'Botline_CA_model_resnet18_' + strToday + '.pth'
TRT_MODEL = 'Botline_CA_model_resnet18_trt_' + strToday + '.pth'

In [8]:
NUM_EPOCHS = 50
best_accuracy = 0.0

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

for epoch in range(NUM_EPOCHS):
    
    for images, labels in iter(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()
    
    test_error_count = 0.0
    for images, labels in iter(test_loader):
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        test_error_count += float(torch.sum(torch.abs(labels - outputs.argmax(1))))
    
    test_accuracy = 1.0 - float(test_error_count) / float(len(test_dataset))
    print('%d: %f' % (epoch, test_accuracy))
    if test_accuracy > best_accuracy:
        torch.save(model.state_dict(), RESNET_MODEL)
        best_accuracy = test_accuracy

0: 0.892308
1: 0.953846
2: 0.953846
3: 0.953846
4: 0.923077
5: 0.938462
6: 0.969231
7: 0.969231
8: 0.984615
9: 1.000000
10: 0.969231
11: 0.969231
12: 0.984615
13: 0.984615
14: 0.969231
15: 0.984615
16: 0.938462
17: 1.000000
18: 0.984615
19: 0.969231
20: 0.984615
21: 0.938462
22: 0.953846
23: 0.969231
24: 0.984615
25: 0.969231
26: 0.984615
27: 0.984615
28: 0.953846
29: 0.984615
30: 0.969231
31: 0.969231
32: 0.969231
33: 0.984615
34: 0.969231
35: 0.984615
36: 0.969231
37: 0.984615
38: 0.969231
39: 0.846154
40: 0.938462
41: 0.923077
42: 0.969231
43: 0.953846
44: 0.953846
45: 0.923077
46: 0.953846
47: 0.892308
48: 0.984615
49: 0.969231


In [9]:
model = torchvision.models.resnet18(pretrained=False)
model.fc = torch.nn.Linear(512, 2)
model = model.cuda().eval().half()

model.load_state_dict(torch.load(RESNET_MODEL))

<All keys matched successfully>

In [10]:
from torch2trt import torch2trt

data = torch.zeros((1, 3, 224, 224)).cuda().half()

model_trt = torch2trt(model, [data], fp16_mode=True)

ModuleNotFoundError: No module named 'tensorrt'

In [None]:
torch.save(model_trt.state_dict(), TRT_MODEL)