In [2]:
import timm
import torch
import torch.nn as nn
import copy
from torch.utils.data import TensorDataset, DataLoader, random_split
from sklearn.metrics import confusion_matrix
from torchvision import transforms, datasets

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
print(timm.__version__)
print(torch.__version__)

0.9.12
2.1.1


In [4]:
class CNN_5Layer(nn.Module):
    def __init__(self):
        super(CNN_5Layer, self).__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.5),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.5),

            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.5),

            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.5),
        )

        dummy_input = torch.randn(1, 3, 224, 224)
        conv_output_size = self._get_conv_output_size(dummy_input)

        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(conv_output_size, 4, bias=True),
            nn.Softmax(dim=1)
        )

        self.l1_regularizer = nn.L1Loss()

    def _get_conv_output_size(self, x):
        with torch.no_grad():
            conv_output = self.conv_layers(x)
        return conv_output.view(x.size(0), -1).shape[1]

    def forward(self, x):
        x = self.conv_layers(x)
        x = self.fc_layers(x)
        return x

In [5]:
class CNN_6Layer(nn.Module):
    def __init__(self):
        super(CNN_6Layer, self).__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.5),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.5),

            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.5),

            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.5),

            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.5),
        )

        dummy_input = torch.randn(1, 3, 224, 224)
        conv_output_size = self._get_conv_output_size(dummy_input)

        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(conv_output_size, 4, bias=True),
            nn.Softmax(dim=1)
        )

        self.l1_regularizer = nn.L1Loss()

    def _get_conv_output_size(self, x):
        with torch.no_grad():
            conv_output = self.conv_layers(x)
        return conv_output.view(x.size(0), -1).shape[1]

    def forward(self, x):
        x = self.conv_layers(x)
        x = self.fc_layers(x)
        return x

In [11]:
def create_model(model_name):
    if model_name == 'vit-tiny':
        model = timm.create_model('vit_tiny_patch16_224', pretrained=False, num_classes=4)
        model_path = '../experiments/non-keyframe/vit-tiny/vit-tiny-130-epochs-early-stopping-tiny.h5'

    if model_name == 'vit-small': 
        model = model = timm.create_model('vit_small_patch16_224', pretrained=False, num_classes=4)
        model_path = '../experiments/non-keyframe/vit-small/vit-small-130-epochs-early-stopping-small.h5'
        
    if model_name == 'vit-base': 
        model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=4)
        model_path = '../experiments/non-keyframe/vit-base/vit-base-130-epochs-early-stopping-base.h5'

    if model_name == 'resnet-18': 
        model = timm.create_model('resnet18', pretrained=False, num_classes=4)
        model_path = '../experiments/non-keyframe/resnet-18/resnet-18-130-epochs-early-stopping-resnet18.h5'
    
    if model_name == 'resnet-34': 
        model = timm.create_model('resnet34', pretrained=False, num_classes=4)
        model_path = '../experiments/non-keyframe/resnet-34/resnet-34-130-epochs-early-stopping-resnet34.h5'
    
    if model_name == 'cnn-5-layer': 
        model = CNN_5Layer() 
        model_path = '../experiments/non-keyframe/cnn-5-layer/cnn-5-layer-130-epochs-early-stopping-with-regularization-5-layer'
    
    if model_name == 'cnn-6-layer': 
        model = CNN_6Layer()
        model_path = '../experiments/non-keyframe/cnn-6-layer/cnn-6-layer-130-epochs-early-stopping-with-regularization-6-layer'

    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)

    return model

In [7]:
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

dataset = datasets.ImageFolder(root='../energy-images-empirical', transform=transform)

total_size = len(dataset)
train_size = int(total_size * 0.8) 
validation_size = int(total_size * 0.1) 
test_size = total_size - train_size - validation_size
generator = torch.Generator().manual_seed(0) 

data_loader = DataLoader(dataset, batch_size=64, shuffle=True)

In [8]:
images = torch.stack([image for image, _ in dataset], dim=0)
labels = torch.tensor([label for _, label in dataset], dtype=torch.int64)

In [13]:
def test(model_name):
    model = create_model(model_name)

    criterion = torch.nn.CrossEntropyLoss()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    model.to(device)

    model.eval()

    test_loss = 0.0
    test_correct = 0
    test_total = 0
    all_predictions = []
    all_labels = []
    total_accuracy = 0

    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(data_loader):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            loss = criterion(outputs, labels)
            test_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()

            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            print(f"Batch {batch_idx}, Loss: {loss.item():.6f}, Accuracy: {100 * test_correct / test_total:.2f}%")
        
    total_accuracy += 100 * test_correct / test_total

    print(f'Model — {model_name}')
    print(f'Average Accuracy — {total_accuracy}')
    print('+=' * 20)

In [14]:
model_types = ['vit-tiny', 'vit-small', 'vit-base', 'resnet-18', 'resnet-34', 'cnn-5-layer', 'cnn-6-layer']
for model in model_types:
    test(model)

Using device: cpu
Batch 0, Loss: 6.461430, Accuracy: 18.75%
Batch 1, Loss: 5.791875, Accuracy: 22.66%
Batch 2, Loss: 5.702439, Accuracy: 22.92%
Batch 3, Loss: 6.059984, Accuracy: 22.27%
Batch 4, Loss: 5.459346, Accuracy: 23.44%
Batch 5, Loss: 5.828105, Accuracy: 22.66%
Batch 6, Loss: 5.864913, Accuracy: 22.54%
Batch 7, Loss: 6.726446, Accuracy: 21.09%
Batch 8, Loss: 5.603005, Accuracy: 21.18%
Batch 9, Loss: 5.681521, Accuracy: 20.62%
Batch 10, Loss: 5.472759, Accuracy: 21.16%
Batch 11, Loss: 5.676647, Accuracy: 20.96%
Batch 12, Loss: 5.940926, Accuracy: 21.03%
Batch 13, Loss: 5.410361, Accuracy: 20.87%
Batch 14, Loss: 5.641367, Accuracy: 20.83%
Batch 15, Loss: 5.454420, Accuracy: 21.39%
Batch 16, Loss: 5.836728, Accuracy: 21.32%
Batch 17, Loss: 6.166385, Accuracy: 21.09%
Batch 18, Loss: 5.986181, Accuracy: 20.89%
Batch 19, Loss: 5.383593, Accuracy: 21.09%
Batch 20, Loss: 5.188151, Accuracy: 21.35%
Batch 21, Loss: 5.403562, Accuracy: 21.66%
Batch 22, Loss: 5.067591, Accuracy: 22.08%
Bat