In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt


class OrganoidDataset(Dataset):

    def __init__(self, transform=None):
        df = np.loadtxt('dfLocalityReward.csv', delimiter=',', dtype=np.float32, skiprows=1)
        self.n_samples = df.shape[0]

        self.x_data = df[:, 1:]
        self.y_data = df[:, [0]]

        self.transform = transform

    def __getitem__(self, index):
        sample = self.x_data[index], self.y_data[index]

        if self.transform:
            sample = self.transform(sample)

        return sample

    def __len__(self):
        return self.n_samples

class ToTensor:
    # Convert ndarrays to Tensors
    def __call__(self, sample):
        inputs, targets = sample
        return torch.from_numpy(inputs), torch.from_numpy(targets)

In [3]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters 
input_size = 1681 # 41x41
output_size = 1
num_epochs = 1000
batch_size = 70
learning_rate = 0.001

dataset = OrganoidDataset(transform=ToTensor())

train_loader = DataLoader(dataset=dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=2)

# Fully connected neural network with one hidden layer
class NeuralNet(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.input_size = input_size
        self.relu = nn.ReLU()
        self.l1 = nn.Linear(input_size, 128)
        self.l2 = nn.Linear(128, 64)
        self.l3 = nn.Linear(64, output_size)
    
    def forward(self, x):
        out = self.l1(x)
        out = self.relu(out)
        out = self.l2(out)
        out = self.relu(out)
        out = self.l3(out)
        out = self.relu(out)
        return out

model = NeuralNet(input_size, output_size).to(device)

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  


# Train the model
n_total_steps = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):  
        
        images = images.reshape(-1, input_size).to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 5 == 0:
            print (f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}')
            
# TO DO #############
# In test phase, we don't need to compute gradients (for memory efficiency)
with torch.no_grad():
    n_correct = 0
    n_samples = 0
    for images, labels in test_loader:
        images = images.reshape(-1, 28*28).to(device)
        labels = labels.to(device)
        outputs = model(images)
        # max returns (value ,index)
        _, predicted = torch.max(outputs.data, 1)
        n_samples += labels.size(0)
        n_correct += (predicted == labels).sum().item()

    acc = 100.0 * n_correct / n_samples
    print(f'Accuracy of the network on the 10000 test images: {acc} %')

Epoch [1/1000], Step [5/5], Loss: 10670.6807
Epoch [2/1000], Step [5/5], Loss: 7921.2500
Epoch [3/1000], Step [5/5], Loss: 9741.2695
Epoch [4/1000], Step [5/5], Loss: 9445.0732
Epoch [5/1000], Step [5/5], Loss: 10038.8770
Epoch [6/1000], Step [5/5], Loss: 10118.9189
Epoch [7/1000], Step [5/5], Loss: 9243.9209
Epoch [8/1000], Step [5/5], Loss: 5741.2319
Epoch [9/1000], Step [5/5], Loss: 5366.1021
Epoch [10/1000], Step [5/5], Loss: 7281.7788
Epoch [11/1000], Step [5/5], Loss: 6052.3027
Epoch [12/1000], Step [5/5], Loss: 5175.8442
Epoch [13/1000], Step [5/5], Loss: 7527.5776
Epoch [14/1000], Step [5/5], Loss: 7169.1709
Epoch [15/1000], Step [5/5], Loss: 8364.0039
Epoch [16/1000], Step [5/5], Loss: 9111.3984
Epoch [17/1000], Step [5/5], Loss: 7184.4922
Epoch [18/1000], Step [5/5], Loss: 7110.1079
Epoch [19/1000], Step [5/5], Loss: 6887.3037
Epoch [20/1000], Step [5/5], Loss: 11486.6807
Epoch [21/1000], Step [5/5], Loss: 8461.1445
Epoch [22/1000], Step [5/5], Loss: 7508.9609
Epoch [23/1000]

Epoch [182/1000], Step [5/5], Loss: 1011.9536
Epoch [183/1000], Step [5/5], Loss: 1112.0081
Epoch [184/1000], Step [5/5], Loss: 1026.3024
Epoch [185/1000], Step [5/5], Loss: 639.7203
Epoch [186/1000], Step [5/5], Loss: 1257.6460
Epoch [187/1000], Step [5/5], Loss: 1082.8817
Epoch [188/1000], Step [5/5], Loss: 1199.4047
Epoch [189/1000], Step [5/5], Loss: 1348.9312
Epoch [190/1000], Step [5/5], Loss: 1089.7516
Epoch [191/1000], Step [5/5], Loss: 1316.3259
Epoch [192/1000], Step [5/5], Loss: 1091.1156
Epoch [193/1000], Step [5/5], Loss: 1185.8165
Epoch [194/1000], Step [5/5], Loss: 1401.0536
Epoch [195/1000], Step [5/5], Loss: 1052.0228
Epoch [196/1000], Step [5/5], Loss: 967.6056
Epoch [197/1000], Step [5/5], Loss: 1358.7236
Epoch [198/1000], Step [5/5], Loss: 1454.3827
Epoch [199/1000], Step [5/5], Loss: 555.0566
Epoch [200/1000], Step [5/5], Loss: 1130.2817
Epoch [201/1000], Step [5/5], Loss: 748.0865
Epoch [202/1000], Step [5/5], Loss: 1211.2448
Epoch [203/1000], Step [5/5], Loss: 11

Epoch [363/1000], Step [5/5], Loss: 671.1443
Epoch [364/1000], Step [5/5], Loss: 418.3719
Epoch [365/1000], Step [5/5], Loss: 675.7605
Epoch [366/1000], Step [5/5], Loss: 715.6686
Epoch [367/1000], Step [5/5], Loss: 701.3871
Epoch [368/1000], Step [5/5], Loss: 657.8614
Epoch [369/1000], Step [5/5], Loss: 566.2260
Epoch [370/1000], Step [5/5], Loss: 573.1192
Epoch [371/1000], Step [5/5], Loss: 730.3590
Epoch [372/1000], Step [5/5], Loss: 633.9476
Epoch [373/1000], Step [5/5], Loss: 709.5140
Epoch [374/1000], Step [5/5], Loss: 626.4027
Epoch [375/1000], Step [5/5], Loss: 859.1207
Epoch [376/1000], Step [5/5], Loss: 789.5765
Epoch [377/1000], Step [5/5], Loss: 543.1583
Epoch [378/1000], Step [5/5], Loss: 654.1031
Epoch [379/1000], Step [5/5], Loss: 482.3334
Epoch [380/1000], Step [5/5], Loss: 539.1484
Epoch [381/1000], Step [5/5], Loss: 598.7173
Epoch [382/1000], Step [5/5], Loss: 617.3027
Epoch [383/1000], Step [5/5], Loss: 762.7953
Epoch [384/1000], Step [5/5], Loss: 666.0994
Epoch [385

Epoch [546/1000], Step [5/5], Loss: 255.6402
Epoch [547/1000], Step [5/5], Loss: 356.3521
Epoch [548/1000], Step [5/5], Loss: 419.8422
Epoch [549/1000], Step [5/5], Loss: 284.3648
Epoch [550/1000], Step [5/5], Loss: 478.8397
Epoch [551/1000], Step [5/5], Loss: 326.2993
Epoch [552/1000], Step [5/5], Loss: 659.5039
Epoch [553/1000], Step [5/5], Loss: 464.8211
Epoch [554/1000], Step [5/5], Loss: 384.7489
Epoch [555/1000], Step [5/5], Loss: 450.6216
Epoch [556/1000], Step [5/5], Loss: 478.4450
Epoch [557/1000], Step [5/5], Loss: 374.5646
Epoch [558/1000], Step [5/5], Loss: 625.8054
Epoch [559/1000], Step [5/5], Loss: 668.8673
Epoch [560/1000], Step [5/5], Loss: 343.3799
Epoch [561/1000], Step [5/5], Loss: 616.9918
Epoch [562/1000], Step [5/5], Loss: 508.0910
Epoch [563/1000], Step [5/5], Loss: 473.1666
Epoch [564/1000], Step [5/5], Loss: 443.8409
Epoch [565/1000], Step [5/5], Loss: 496.5699
Epoch [566/1000], Step [5/5], Loss: 398.4383
Epoch [567/1000], Step [5/5], Loss: 482.5097
Epoch [568

Epoch [729/1000], Step [5/5], Loss: 364.8259
Epoch [730/1000], Step [5/5], Loss: 379.9243
Epoch [731/1000], Step [5/5], Loss: 471.3517
Epoch [732/1000], Step [5/5], Loss: 458.9682
Epoch [733/1000], Step [5/5], Loss: 513.3681
Epoch [734/1000], Step [5/5], Loss: 378.1187
Epoch [735/1000], Step [5/5], Loss: 440.5431
Epoch [736/1000], Step [5/5], Loss: 556.3600
Epoch [737/1000], Step [5/5], Loss: 393.1201
Epoch [738/1000], Step [5/5], Loss: 479.9062
Epoch [739/1000], Step [5/5], Loss: 344.6320
Epoch [740/1000], Step [5/5], Loss: 347.1778
Epoch [741/1000], Step [5/5], Loss: 373.9508
Epoch [742/1000], Step [5/5], Loss: 667.5547
Epoch [743/1000], Step [5/5], Loss: 435.9422
Epoch [744/1000], Step [5/5], Loss: 493.2754
Epoch [745/1000], Step [5/5], Loss: 528.8316
Epoch [746/1000], Step [5/5], Loss: 434.8668
Epoch [747/1000], Step [5/5], Loss: 318.9039
Epoch [748/1000], Step [5/5], Loss: 386.2627
Epoch [749/1000], Step [5/5], Loss: 431.8095
Epoch [750/1000], Step [5/5], Loss: 494.9550
Epoch [751

Epoch [912/1000], Step [5/5], Loss: 320.9514
Epoch [913/1000], Step [5/5], Loss: 462.0907
Epoch [914/1000], Step [5/5], Loss: 536.8870
Epoch [915/1000], Step [5/5], Loss: 444.9052
Epoch [916/1000], Step [5/5], Loss: 407.3192
Epoch [917/1000], Step [5/5], Loss: 318.6121
Epoch [918/1000], Step [5/5], Loss: 265.2885
Epoch [919/1000], Step [5/5], Loss: 386.5047
Epoch [920/1000], Step [5/5], Loss: 317.0707
Epoch [921/1000], Step [5/5], Loss: 432.9178
Epoch [922/1000], Step [5/5], Loss: 586.7516
Epoch [923/1000], Step [5/5], Loss: 405.1638
Epoch [924/1000], Step [5/5], Loss: 616.8998
Epoch [925/1000], Step [5/5], Loss: 279.4110
Epoch [926/1000], Step [5/5], Loss: 462.3993
Epoch [927/1000], Step [5/5], Loss: 345.8147
Epoch [928/1000], Step [5/5], Loss: 305.4797
Epoch [929/1000], Step [5/5], Loss: 343.8955
Epoch [930/1000], Step [5/5], Loss: 300.5880
Epoch [931/1000], Step [5/5], Loss: 301.7928
Epoch [932/1000], Step [5/5], Loss: 350.9189
Epoch [933/1000], Step [5/5], Loss: 374.9013
Epoch [934