In [17]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from dataset import SubsetCIFAR
classes_to_keep = {0, 1} # set for fast existence check
# Create a new dataset that only keeps the desired classes


In [12]:
transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Download and load the training data
trainset = SubsetCIFAR(root='./data', train=True, download=True, indices=classes_to_keep, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
print(trainset.targets)
# Download and load the test data
testset = SubsetCIFAR(root='./data', train=False, download=True, indices=classes_to_keep,transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)
print(testset.targets)


Files already downloaded and verified
[1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1

In [18]:
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(32*32*3, 128)
        self.fc2 = nn.Linear(128, 1)
        # self.fc3 = nn.Linear(256, 10)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        x = x.view(-1, 32*32*3)  # Flatten the image
        x = self.relu(self.fc1(x))
        x = self.sigmoid(self.fc2(x))
        # x = self.fc3(x)
        return x

    # Initialize the model, loss function, and optimizer
model = SimpleNN()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)


In [24]:
def train(model, trainloader, criterion, optimizer, epochs=2):
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs,labels in trainloader:
            #inputs, labels = data
            
            optimizer.zero_grad()
            
            outputs = model(inputs)
            # print("op",outputs.squeeze(dim=1))
            # print("label",labels.float())
            loss = criterion(outputs.squeeze(dim=1), labels.float())
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            # if i % 100 == 99:  # print every 100 mini-batches
            #     print(f'Epoch {epoch + 1}, Batch {i + 1}, Loss: {running_loss / 100:.4f}')
            #     running_loss = 0.0

    # Evaluation loop
def evaluate(model, testloader):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            outputs = model(images)
            print(outputs)
            print(labels)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            break
    
    print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:.2f}%')



In [25]:
train(model, trainloader, criterion, optimizer, epochs=2)

    # Evaluate the model
    

In [26]:
evaluate(model, testloader)

tensor([[0.1939],
        [0.7193],
        [0.8397],
        [0.2417],
        [0.0259],
        [0.3138],
        [0.7764],
        [0.0521],
        [0.7051],
        [0.9083],
        [0.5125],
        [0.7711],
        [0.7559],
        [0.4060],
        [0.2477],
        [0.1097],
        [0.8064],
        [0.9080],
        [0.2310],
        [0.6710],
        [0.1427],
        [0.7451],
        [0.7262],
        [0.6227],
        [0.6075],
        [0.2450],
        [0.2600],
        [0.8566],
        [0.2951],
        [0.2559],
        [0.5820],
        [0.0702],
        [0.5698],
        [0.6966],
        [0.3913],
        [0.7590],
        [0.4099],
        [0.3288],
        [0.9714],
        [0.0635],
        [0.2587],
        [0.7349],
        [0.6956],
        [0.0813],
        [0.5736],
        [0.2347],
        [0.5493],
        [0.4026],
        [0.5974],
        [0.4013],
        [0.1762],
        [0.4301],
        [0.7108],
        [0.8886],
        [0.4151],
        [0

In [None]:
class SubsetCIFAR(torchvision.datasets.CIFAR10):
    def __init__(self, root, train=True, transform=None, download=True, indices=None):
        super().__init__(root, train=train, transform=transform, download=download)
        keepDeez = {indx:klass for indx, klass in enumerate(self.targets) if klass in indices} 
        self.targets = list(keepDeez.values())
        self.data = self.data[list(keepDeez.keys())]