In [16]:
"""
Erik Wallin
Task 2

Using autoencoders, we get an unsupervised network that learns a representation of the data in 
the small bottle-neck layer, the latent space. The latent space gives us a dimensionality
reduction, which shows that it has learned some representation of the data.

For anomaly detection we need an unsupervised training method since we can't easily label wether an
experimental image has a DM-substructure or not. If we could, that would somewhat defeat the purpose. 

I will train the autoencoder on both images with dark matter-substructure and without, so that one can use
anomaly detection to find experimental images that does not fit either these models. 

One could also train two such autoencoders on only one dataset each, to compare how they 
reconstruct the experimental data. 

I did not see this as a task about classification, i.e. whether or not a certain image has some
dark matter substructure or not. For that some supervised classification method should be used. 

This proof of concept uses a purely linear network, with the AdamW optimizer, a SGD-like variant with weight-decay
which often shows good results for quick learning. 

"""

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision

data_path ="../lenses"

dataset = torchvision.datasets.ImageFolder(
    root=data_path,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.Grayscale(1),
        torchvision.transforms.Resize((50,50)),
        torchvision.transforms.ToTensor()
        ])
)


train,_ = torch.utils.data.random_split(dataset, [8000,2000])


#The data is naturally in a suitable range from 0 to 1 after being grayscaled. Also,
#the mean value of the data is somewhere around 'black', so it doesn't need to be normalized much.

train_loader = torch.utils.data.DataLoader(train,
                                          batch_size=1,
                                          shuffle=True,
                                          )

validation_loader = torch.utils.data.DataLoader(train,
                                          batch_size=1,
                                          shuffle=True,
                                          )



"""
One should figure out some more intelligent loss function, since most of the image will be 
black space quite naturally. One could give it more weight to areas with actual activity. 
"""
lossfunc = nn.MSELoss()


"""
Implements a 50*50 - 1000 - 1000 - 500 - 1000- 1000 - 50*50 linear network
Thats a 5-time dimensionality reduction to the latent space. 
With time one could do some nice hyper-parameter scans to find an optimal network structure. 
"""
class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()

        self.n1 = nn.Linear(50*50,1000)
        self.n2 = nn.Linear(1000,1000)
        self.n3 = nn.Linear(1000,500)
        self.n4 = nn.Linear(500,1000)
        self.n5 = nn.Linear(1000,1000)
        self.n6 = nn.Linear(1000,50*50)
        
        self.func = nn.Sigmoid()
        
    def forward(self, x):
        encoded = self.n3(self.func(self.n2(self.func(self.n1(x)))))
        return self.n6(self.func(self.n5(self.func(self.n4(self.func(encoded))))))

model = Network()
optimizer = optim.Adam(model.parameters(), lr=1e-6)

In [17]:
lossavg = 0

"""
One epoch of training through the training data. And periodically prints the average loss 
(on the training data). It is quite slow, as there are many weights in the network,
due to our high input dimension. 
"""
for i, data in enumerate(train_loader, 0):
    inputs = torch.flatten(data[0])
    
    optimizer.zero_grad()
    
    outputs = model.forward(inputs)
    
    loss = lossfunc(outputs, inputs)
    loss.backward()
    optimizer.step()
    lossavg+= loss.item()
    if i % 100 == 0 and i != 0:
        print(lossavg/100)
        lossavg=0

0.07815547697246075
0.05660406120121479
0.04258782185614109
0.03237452702596784
0.024387747403234242
0.018783607957884668
0.014794549737125635
0.011167375966906548
0.008862074352800847
0.006958424090407788
0.005597779592499137
0.004511185938026756
0.003842435907572508
0.0032210370036773383
0.002696759554091841
0.0022303017100784926
0.00209053571568802
0.0018259528710041196
0.001771223340765573
0.0016263902379432694
0.001558453066390939
0.0016015682026045397
0.001496130115701817
0.0015366951160831377
0.0014643413631711154
0.0015739518363261595
0.001471261830884032
0.0013547899192781188
0.0013891249499283732
0.0015452419803477823
0.0013551101263146847
0.0013223755138460547
0.001534765813848935
0.0014358618465485052
0.0014101557707181201
0.0014366814238019289
0.0015892190803424456
0.0014738166576717048
0.0014629695599433035
0.0015501493751071394
0.0013248483563074842
0.0012914605240803212
0.0011914023023564368
0.0015264532389119268
0.0016372486966429278
0.0015044055751059205
0.00151643355

The training error easily converges to around 1E-3. As does the validation error as seen below. 

In [18]:
for i, data in enumerate(validation_loader, 0):
    inputs = torch.flatten(data[0])
    
    outputs = model.forward(inputs)
    
    loss = lossfunc(outputs, inputs)

    lossavg+= loss.item()
    if i % 100 == 0  and i != 0:
        print(lossavg/100)
        lossavg=0

0.0029824086019652894
0.0014063676772639155
0.001421797553775832
0.0015576968158711678
0.0015116632322315126
0.001532444479817059
0.0015041677915723995
0.0012812459736596792
0.0013474100246094168
0.001473102539894171
0.0014135721075581387
0.0014248162019066513
0.001583937234536279
0.0015198809604044072
0.0014858728897524998
0.001505210014001932
0.0013834289537044242
0.0015335165857686662
0.001503364919917658
0.0014934874454047532
0.0016648168058600277
0.0015268234658287838
0.0014254972751950846
0.0014867703424533828
0.0014334265748038887
0.0014942449645604938
0.0015203773477696814
0.0015014386596158148
0.0014792031282559036
0.0013540465600090101
0.0014885884837713094
0.0014898042191634885
0.0013921119458973408
0.0014424055145354942
0.0013793741242261604
0.0015112112552742474
0.0013691743827075698
0.0014268266048748047
0.001461311118910089
0.0017029963858658447
0.0014592290244763717
0.001429881940712221
0.0013525637332350016
0.001461145468056202
0.0013899812774616294
0.00131992403476033