In [15]:
import timm
import torch
import torch.nn as nn
import copy
from torch.utils.data import TensorDataset, DataLoader, random_split
from sklearn.metrics import confusion_matrix
from torchvision import transforms, datasets

In [16]:
print(timm.__version__)
print(torch.__version__)

0.9.12
2.1.1


In [17]:
class CNN_5Layer(nn.Module):
    def __init__(self):
        super(CNN_5Layer, self).__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.5),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.5),

            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.5),

            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.5),
        )

        dummy_input = torch.randn(1, 3, 224, 224)
        conv_output_size = self._get_conv_output_size(dummy_input)

        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(conv_output_size, 4, bias=True),
            nn.Softmax(dim=1)
        )

        self.l1_regularizer = nn.L1Loss()

    def _get_conv_output_size(self, x):
        with torch.no_grad():
            conv_output = self.conv_layers(x)
        return conv_output.view(x.size(0), -1).shape[1]

    def forward(self, x):
        x = self.conv_layers(x)
        x = self.fc_layers(x)
        return x

In [18]:
class CNN_6Layer(nn.Module):
    def __init__(self):
        super(CNN_6Layer, self).__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.5),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.5),

            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.5),

            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.5),

            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.5),
        )

        dummy_input = torch.randn(1, 3, 224, 224)
        conv_output_size = self._get_conv_output_size(dummy_input)

        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(conv_output_size, 4, bias=True),
            nn.Softmax(dim=1)
        )

        self.l1_regularizer = nn.L1Loss()

    def _get_conv_output_size(self, x):
        with torch.no_grad():
            conv_output = self.conv_layers(x)
        return conv_output.view(x.size(0), -1).shape[1]

    def forward(self, x):
        x = self.conv_layers(x)
        x = self.fc_layers(x)
        return x

In [19]:
def create_model(model_name):
    if model_name == 'vit-tiny':
        model = timm.create_model('vit_tiny_patch16_224', pretrained=False, num_classes=4)
        model_path = '../experiments/non-keyframe/vit-tiny/vit-tiny-130-epochs-early-stopping-tiny.h5'

    if model_name == 'vit-small': 
        model = model = timm.create_model('vit_small_patch16_224', pretrained=False, num_classes=4)
        model_path = '../non-keyframe/vit-small/vit-small-130-epochs-early-stopping-small.h5'
        
    if model_name == 'vit-base': 
        model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=4)
        model_path = '../non-keyframe/vit-base/vit-base-130-epochs-early-stopping-base.h5'

    if model_name == 'resnet-18': 
        model = timm.create_model('resnet18', pretrained=False, num_classes=4)
        model_path = '../non-keyframe/resnet-18/resnet-18-130-epochs-early-stopping-resnet18.h5'
    
    if model_name == 'resnet-34': 
        model = timm.create_model('resnet34', pretrained=False, num_classes=4)
        model_path = '../non-keyframe/resnet-34/resnet-34-130-epochs-early-stopping-resnet34.h5'
    
    if model_name == 'cnn-5-layer': 
        model = CNN_5Layer() 
        model_path = '../non-keyframe/cnn-5-layer/cnn-5-layer-130-epochs-early-stopping-with-regularization-5-layer'
    
    if model_name == 'cnn-6-layer': 
        model = CNN_6Layer()
        model_path = '../non-keyframe/cnn-6-layer/cnn-6-layer-130-epochs-early-stopping-with-regularization-6-layer'

    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)

    return model

In [20]:
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

dataset = datasets.ImageFolder(root='../energy-images-empirical', transform=transform)

total_size = len(dataset)
train_size = int(total_size * 0.8) 
validation_size = int(total_size * 0.1) 
test_size = total_size - train_size - validation_size
generator = torch.Generator().manual_seed(0) 

data_loader = DataLoader(dataset, batch_size=64, shuffle=True)

In [21]:
images = torch.stack([image for image, _ in dataset], dim=0)
labels = torch.tensor([label for _, label in dataset], dtype=torch.int64)

In [22]:
def test(model_name):
    model = create_model(model_name)

    criterion = torch.nn.CrossEntropyLoss()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    model.to(device)

    model.eval()

    test_loss = 0.0
    test_correct = 0
    test_total = 0
    all_predictions = []
    all_labels = []
    total_accuracy = 0

    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(data_loader):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            loss = criterion(outputs, labels)
            test_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()

            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            print(f"Class {batch_idx}, Loss: {loss.item():.6f}, Accuracy: {100 * test_correct / test_total:.2f}%")
        
    total_accuracy += 100 * test_correct / test_total

    print(f'Model — {model_name}')
    print(f'Average Accuracy — {total_accuracy/4}')

In [23]:
model_types = ['vit-tiny', 'vit-small', 'vit-base', 'resnet-18', 'resnet-34', 'cnn-5-layer', 'cnn-6-layer']
for model in model_types:
    print(model)
    test(model)

Using device: cpu
Class 0, Loss: 5.575579, Accuracy: 23.44%
Class 1, Loss: 5.473075, Accuracy: 25.00%
Class 2, Loss: 5.270370, Accuracy: 25.52%
Class 3, Loss: 5.168959, Accuracy: 26.56%
Class 4, Loss: 5.940934, Accuracy: 26.25%
Class 5, Loss: 5.728189, Accuracy: 24.74%
Class 6, Loss: 6.375771, Accuracy: 23.21%
Class 7, Loss: 6.290821, Accuracy: 22.46%
Class 8, Loss: 5.398467, Accuracy: 22.57%
Class 9, Loss: 5.288901, Accuracy: 23.12%
Class 10, Loss: 5.187339, Accuracy: 23.30%
Class 11, Loss: 6.097549, Accuracy: 22.79%
Class 12, Loss: 5.602042, Accuracy: 22.72%
Class 13, Loss: 5.379881, Accuracy: 23.21%
Class 14, Loss: 5.822414, Accuracy: 23.23%
Class 15, Loss: 5.840580, Accuracy: 22.56%
Class 16, Loss: 6.033894, Accuracy: 22.43%
Class 17, Loss: 5.251490, Accuracy: 22.66%
Class 18, Loss: 5.558633, Accuracy: 22.78%
Class 19, Loss: 6.358394, Accuracy: 22.50%
Class 20, Loss: 5.593822, Accuracy: 22.54%


KeyboardInterrupt: 