In [34]:
import torch
import nvtripy as tp
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms


class TripyLeNet(tp.Module):
    """LeNet-style convolutional classifier assembled from nvtripy layers.

    The layer attributes (``conv1``, ``conv2``, ``fc1``, ``fc2``, ``out``) use
    the same names as the PyTorch ``LeNet`` defined in this notebook, so a
    converted PyTorch state dict can be loaded by key.

    Args:
        dtype: nvtripy data type passed to every layer (default ``tp.float32``).
    """

    def __init__(self, dtype=tp.float32):
        super().__init__()

        # Two 5x5 convolutions: 3 -> 6 -> 16 channels.
        self.conv1 = tp.Conv(in_channels=3, out_channels=6, kernel_dims=(5, 5), dtype=dtype)
        self.conv2 = tp.Conv(in_channels=6, out_channels=16, kernel_dims=(5, 5), dtype=dtype)

        # Fully connected head: 16*5*5 -> 120 -> 84 -> 10.
        self.fc1 = tp.Linear(16 * 5 * 5, 120, dtype=dtype)
        self.fc2 = tp.Linear(120, 84, dtype=dtype)
        self.out = tp.Linear(84, 10, dtype=dtype)

    def forward(self, x):
        """Run the network on ``x`` and return the output of the final linear layer.

        Args:
            x: Input batch; expected layout is (N, 3, 32, 32).

        Returns:
            The result of ``self.out`` applied to the last hidden activation.
        """
        # Each convolutional stage is conv -> ReLU -> 2x2 max-pool with stride 2.
        for conv in (self.conv1, self.conv2):
            x = tp.maxpool(tp.relu(conv(x)), kernel_dims=(2, 2), stride=(2, 2))

        # Flatten the feature maps to (N, 16*5*5) while keeping the batch axis.
        x = tp.reshape(x, (x.shape[0], -1))

        # Hidden fully connected layers, each followed by ReLU.
        for hidden in (self.fc1, self.fc2):
            x = tp.relu(hidden(x))

        return self.out(x)


def convert_state_torch_to_tripy(torch_state):
    """Convert a PyTorch state dict into a dict of nvtripy tensors.

    Every value is detached, moved to the CPU, turned into a NumPy array,
    cast to float32 and wrapped in ``tp.Tensor``. Keys are kept unchanged.

    Args:
        torch_state: Mapping from parameter name to torch tensor.

    Returns:
        A new dict with the same keys and ``tp.Tensor`` values.
    """
    return {
        key: tp.Tensor(tensor.detach().cpu().numpy().astype("float32"))
        for key, tensor in torch_state.items()
    }


# Load the checkpoint onto the CPU so this cell also works on machines without
# CUDA; the conversion below moves every tensor to host memory anyway.
# weights_only=True restricts unpickling to tensors and plain containers
# instead of arbitrary Python objects.
torch_state = torch.load(
    "lenet_cifar10_best_pytorch.pth",
    map_location="cpu",
    weights_only=True,
)
tripy_state = convert_state_torch_to_tripy(torch_state)

In [35]:
# Use the GPU for the PyTorch model when CUDA is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)


class LeNet(nn.Module):
    """LeNet-style convolutional classifier in PyTorch.

    Two 5x5 convolutions (3 -> 6 -> 16 channels), each followed by ReLU and a
    2x2 max-pool with stride 2, then three linear layers
    (16*5*5 -> 120 -> 84 -> 10). The attribute names match ``TripyLeNet`` so
    both models can be loaded from the same set of parameter names.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)

        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.out = nn.Linear(in_features=84, out_features=10)

    def forward(self, t):
        """Run the network on ``t`` and return the output of the final linear layer.

        Args:
            t: Input batch; (N, 3, 32, 32) yields the 16*5*5 features that
                ``fc1`` expects.

        Returns:
            Tensor of shape (N, 10).

        Raises:
            RuntimeError: If the flattened feature size does not match
                ``fc1`` (for example, a different spatial input size).
        """
        t = self.conv1(t)
        t = F.relu(t)
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        t = self.conv2(t)
        t = F.relu(t)
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        # Flatten everything except the batch axis. A reshape to (-1, 400)
        # would silently change the batch size whenever the per-sample feature
        # count is not 400; keeping the batch axis makes fc1 raise instead and
        # matches TripyLeNet, which reshapes to (x.shape[0], -1).
        t = t.flatten(start_dim=1)
        t = self.fc1(t)
        t = F.relu(t)

        t = self.fc2(t)
        t = F.relu(t)

        t = self.out(t)
        return t

# Per-channel normalization statistics shared by the train and test pipelines.
mean = (0.4914, 0.4822, 0.4465)
std = (0.2023, 0.1994, 0.2010)

# Training pipeline: random 32x32 crop after 4-pixel padding and a random
# horizontal flip, then tensor conversion and normalization.
train_transform = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)

# Test pipeline: no augmentation, only tensor conversion and normalization.
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)

# CIFAR-10 splits, downloaded into ./data when requested via download=True.
train_set = torchvision.datasets.CIFAR10(
    root='./data', train=True, transform=train_transform, download=True
)
test_set = torchvision.datasets.CIFAR10(
    root='./data', train=False, transform=test_transform, download=True
)

# Batches of 200 samples; only the training loader shuffles.
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=200, shuffle=True, num_workers=2, pin_memory=True
)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=200, shuffle=False, num_workers=2, pin_memory=True
)

Device: cuda


In [58]:
# Cross-entropy averaged over the batch (the default reduction, spelled out
# here): evaluate() and evaluate_tripy() multiply the batch loss by the batch
# size before accumulating it.
criterion = nn.CrossEntropyLoss(reduction="mean")

@torch.no_grad()
def evaluate(model, data_loader, device):
    """Compute the average loss and accuracy of a PyTorch model over a loader.

    The model is switched to eval mode and stays in it afterwards. The loss
    comes from the module-level ``criterion``.

    Args:
        model: PyTorch module mapping an image batch to class scores.
        data_loader: Iterable of ``(images, labels)`` batches.
        device: Device that each batch is moved to before the forward pass.

    Returns:
        Tuple ``(avg_loss, acc)``: the sample-weighted mean loss and the
        fraction of correctly classified samples.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for images, labels in data_loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        logits = model(images)

        # Scale the batch loss by the batch size so that batches of different
        # sizes are weighted correctly in the final average.
        batch_loss = criterion(logits, labels).item()
        loss_sum += batch_loss * images.size(0)

        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += labels.size(0)

    return loss_sum / n_seen, n_correct / n_seen

def to_tripy_tensor(x_torch):
    """Build an nvtripy tensor from a torch tensor via nested Python lists.

    Args:
        x_torch: Torch tensor to convert.

    Returns:
        ``tp.Tensor`` constructed from ``x_torch.tolist()``.
    """
    # NOTE(review): the round trip through Python lists copies every element
    # individually; a buffer-based conversion may be faster, but confirm the
    # resulting dtype before switching.
    return tp.Tensor(x_torch.tolist())

def to_torch_tensor(x_tripy):
    """Build a torch tensor from an object exposing ``tolist()``.

    Args:
        x_tripy: nvtripy tensor (or any object whose ``tolist()`` returns
            nested Python lists of numbers).

    Returns:
        ``torch.Tensor`` constructed from ``x_tripy.tolist()``.
    """
    return torch.Tensor(x_tripy.tolist())

@torch.no_grad()
def evaluate_tripy(tripy_model, data_loader):
    """Compute the average loss and accuracy of an nvtripy model over a loader.

    Each image batch is converted to an nvtripy tensor for the forward pass,
    and the model output is converted back to a torch tensor so the loss
    (module-level ``criterion``) and accuracy are computed with PyTorch.

    Args:
        tripy_model: Callable nvtripy model mapping an image batch to class scores.
        data_loader: Iterable of ``(images, labels)`` torch batches.

    Returns:
        Tuple ``(avg_loss, acc)``: the sample-weighted mean loss and the
        fraction of correctly classified samples.
    """
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for images, labels in data_loader:
        # torch -> nvtripy for inference, then back to torch for the metrics.
        logits = to_torch_tensor(tripy_model(to_tripy_tensor(images)))

        # Scale the batch loss by the batch size so that batches of different
        # sizes are weighted correctly in the final average.
        batch_loss = criterion(logits, labels).item()
        loss_sum += batch_loss * labels.size(0)

        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += labels.size(0)

    return loss_sum / n_seen, n_correct / n_seen


In [59]:
# PyTorch reference model: move it to the selected device and load the
# checkpoint weights.
torch_model = LeNet().to(device)
torch_model.load_state_dict(torch_state)

# nvtripy model initialised from the converted copy of the same weights.
tripy_model = TripyLeNet()
tripy_model.load_state_dict(tripy_state)

# Each call returns (average loss, accuracy) on the test loader; printing both
# lets the two implementations be compared side by side.
print(evaluate(torch_model, test_loader, device))
print(evaluate_tripy(tripy_model, test_loader))

(0.7994318306446075, 0.727)
(0.7994540703296661, 0.7269)
