In [8]:
import torch
import nvtripy as tp

class TripyLeNet(tp.Module):
    """LeNet-style convolutional classifier implemented with nvtripy.

    Layer names and sizes (conv1, conv2, fc1, fc2, out) mirror the PyTorch
    reference model, so a converted PyTorch state dict with the same keys can
    be loaded via ``load_state_dict``.

    Args:
        dtype: tripy dtype passed to every layer (default ``tp.float32``).
    """

    def __init__(self, dtype=tp.float32):
        super().__init__()

        # Feature extractor: two 5x5 convolutions, 3 -> 6 -> 16 channels
        # (same channel counts and kernel sizes as the PyTorch version).
        kernel = (5, 5)
        self.conv1 = tp.Conv(in_channels=3, out_channels=6, kernel_dims=kernel, dtype=dtype)
        self.conv2 = tp.Conv(in_channels=6, out_channels=16, kernel_dims=kernel, dtype=dtype)

        # Classifier head. For a 32x32 input (and assuming unpadded, stride-1
        # convolutions) the spatial size goes 32 -> 28 -> 14 -> 10 -> 5,
        # giving 16 * 5 * 5 flattened features.
        flat_features = 16 * 5 * 5
        self.fc1 = tp.Linear(flat_features, 120, dtype=dtype)
        self.fc2 = tp.Linear(120, 84, dtype=dtype)
        self.out = tp.Linear(84, 10, dtype=dtype)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: input batch, expected shape (N, 3, 32, 32).

        Returns:
            Logits tensor with 10 values per sample (output of ``self.out``).
        """
        pool_args = dict(kernel_dims=(2, 2), stride=(2, 2))

        # Two conv -> ReLU -> 2x2 max-pool stages.
        h = tp.maxpool(tp.relu(self.conv1(x)), **pool_args)
        h = tp.maxpool(tp.relu(self.conv2(h)), **pool_args)

        # Flatten everything except the batch dimension: (N, 16 * 5 * 5).
        h = tp.reshape(h, (h.shape[0], -1))

        # Fully connected head, ReLU after each hidden layer.
        h = tp.relu(self.fc1(h))
        h = tp.relu(self.fc2(h))
        return self.out(h)

In [9]:
def convert_state_torch_to_tripy(torch_state, dtype="float32"):
    """Convert a PyTorch state dict into a dict of tripy tensors.

    Each tensor is detached, moved to the CPU, converted to a NumPy array,
    cast to ``dtype`` and wrapped in ``tp.Tensor``. Keys (and their order)
    are preserved, so the result can be passed to ``load_state_dict`` of a
    tripy module whose parameter names match the PyTorch model.

    Args:
        torch_state: mapping of parameter name -> ``torch.Tensor``
            (e.g. a loaded ``state_dict``).
        dtype: NumPy dtype every tensor is cast to. Defaults to ``"float32"``,
            which matches the default tripy model; pass e.g. ``"float16"`` when
            the target model was built with a matching half-precision dtype.
            Note that every entry, including non-floating ones, is cast.

    Returns:
        dict mapping the same names to ``tp.Tensor`` objects.
    """
    return {
        name: tp.Tensor(param.detach().cpu().numpy().astype(dtype))
        for name, param in torch_state.items()
    }
# Load the reference PyTorch weights.
# - map_location="cpu": a checkpoint saved on GPU also loads on CPU-only
#   machines; the converter moves every tensor to the CPU anyway.
# - weights_only=True: restricts unpickling to tensors and plain containers,
#   so a tampered checkpoint file cannot execute arbitrary code. The file is
#   expected to hold a state dict (name -> tensor), which is all the
#   converter accepts.
torch_state = torch.load("lenet_init.pth", map_location="cpu", weights_only=True)
tripy_state = convert_state_torch_to_tripy(torch_state)

# Build the tripy model and load the converted weights into it.
tripy_model = TripyLeNet()
tripy_model.load_state_dict(tripy_state)

# Smoke test: run a single all-ones 3x32x32 image through the model.
x = tp.ones((1, 3, 32, 32))
print(tripy_model(x))

tensor(
    [[-0.0875086, 0.150313, -0.0598906, 0.00871439, -0.0239175, 0.0357586, 0.111011, 0.0426953, -0.00735878, 0.0625067]], 
    dtype=float32, loc=gpu:0, shape=(1, 10))
