In [1]:
# Install required packages.
import os
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch_geometric
from torch_geometric.data import DataLoader
import torch.optim as optim


torch.multiprocessing.set_sharing_strategy('file_system')

In [2]:
# Définir les transformations pour les images
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

In [3]:
# Télécharger l'ensemble de données CIFAR-10 et le charger
trainset = torchvision.datasets.CIFAR10(root='C:/Users/hp/Desktop/Python Projects/cifar-10-batches-py', train=True,
                                        download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='C:/Users/hp/Desktop/Python Projects/cifar-10-batches-py', train=False,
                                       download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=2)


# Définir les classes de l'ensemble de données CIFAR-10
classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
print(len(trainset))
print(trainset[10])

Files already downloaded and verified
Files already downloaded and verified
50000
(tensor([[[-0.5294, -0.5294, -0.5373,  ..., -0.6863, -0.7569, -0.7882],
         [-0.5451, -0.5373, -0.5451,  ..., -0.6863, -0.7490, -0.7804],
         [-0.5686, -0.5608, -0.5529,  ..., -0.6941, -0.7412, -0.7647],
         ...,
         [-0.3804, -0.3725, -0.3647,  ..., -0.5294, -0.5922, -0.6235],
         [-0.3647, -0.3647, -0.3569,  ..., -0.5294, -0.5843, -0.6157],
         [-0.3569, -0.3569, -0.3490,  ..., -0.5294, -0.5843, -0.6157]]]), 4)


In [4]:
import torch.nn as nn
from gcn_lib.dense.torch_vertex import DynConv2d
from torch._six import inf
# gcn_lib is downloaded from https://github.com/lightaime/deep_gcns_torch

class GrapherModule(nn.Module):
    """Grapher module with graph conv and FC layers
    """
    def _init_(self, in_channels, hidden_channels, k=9, dilation=1, drop_path=0.0):
        super(GrapherModule, self)._init_()
        self.fc1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0),
            nn.BatchNorm2d(in_channels),
        )
        self.graph_conv = nn.Sequential(
            DynConv2d(in_channels, hidden_channels, k, dilation, act=None),
            nn.BatchNorm2d(hidden_channels),
            nn.GELU(),
        )
        self.fc2 = nn.Sequential(
            nn.Conv2d(hidden_channels, in_channels, 1, stride=1, padding=0),
            nn.BatchNorm2d(in_channels),
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        x = x.reshape(B, C, -1, 1).contiguous()
        shortcut = x
        x = self.fc1(x)
        x = self.graph_conv(x)
        x = self.fc2(x)
        x = self.drop_path(x) + shortcut
        return x.reshape(B, C, H, W)

class FFNModule(nn.Module):
    """Feed-forward Network
    """
    def _init_(self, in_channels, hidden_channels, drop_path=0.0):
        super(FFNModule, self)._init_()
        self.fc1 = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, 1, stride=1, padding=0),
            nn.BatchNorm2d(hidden_channels),
            nn.GELU()
        )
        self.fc2 = nn.Sequential(
            nn.Conv2d(hidden_channels, in_channels, 1, stride=1, padding=0),
            nn.BatchNorm2d(in_channels),
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        shortcut = x
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.drop_path(x) + shortcut
        return x

class ViGBlock(nn.Module):
    """ViG block with Grapher and FFN modules"""
    def __init__(self, channels, k, dilation, drop_path=0.0):
        super(ViGBlock, self).__init__()
        self.grapher = GrapherModule(channels, channels * 2, k, dilation, drop_path)
        self.ffn = FFNModule(channels, channels * 4, drop_path)

    def forward(self, x):
        x = self.grapher(x)
        x = self.ffn(x)
        return x


ModuleNotFoundError: No module named 'torch._six'

In [5]:
# print the model
print(model)
# définir la fonction de perte
criterion = nn.CrossEntropyLoss()

# définir l'optimiseur
optimizer = optim.Adam(model.parameters(), lr=0.001)
def train(model, device, trainloader, optimizer, criterion, epoch):
    model.train()
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (data, target) in enumerate(trainloader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()

        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        optimizer.step()

        train_loss += loss.item()
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(trainloader.dataset)} '
                  f'({100. * batch_idx / len(trainloader):.0f}%)]\tLoss: {loss.item():.6f}')

    train_loss /= len(train_loader)
    train_acc = 100. * correct / total

    print(f'Train Epoch: {epoch} Average loss: {train_loss:.4f}, Accuracy: {train_acc:.2f}%')

    return train_loss, train_acc

NameError: name 'model' is not defined

In [6]:
if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # define the model
    model = ViGBlock(in_channels=64, k=9, dilation=1, drop_path=0.1)
    model.to(device)

    # define the loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # train the model
    train(model, device, trainloader, optimizer, criterion, epoch=10)

NameError: name 'ViGBlock' is not defined