In [338]:
import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader
import pandas as pd

In [339]:
# Read Data
spectrum_train = pd.read_excel('spectrum_train.xlsx')
spectrum_test = pd.read_excel('spectrum_valid.xlsx')
temp_train = pd.read_excel('temp_train.xlsx')
temp_test = pd.read_excel('temp_valid.xlsx')

In [340]:
spectrum_train_scaled = spectrum_train.multiply(10**12)
spectrum_test_scaled = spectrum_test.multiply(10**12)

In [341]:
# determine the supported device
def get_device():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu') # don't have GPU 
    return device

# convert a df to tensor to be used in pytorch
def df_to_tensor(df):
    device = get_device()
    return torch.from_numpy(df.values).float().to(device)

In [342]:
input_data = df_to_tensor(temp_train)
output_data = df_to_tensor(spectrum_train_scaled)

In [343]:
input_data

tensor([[350.6424, 357.4240, 338.0000,  ..., 267.9975, 252.9999, 238.0000],
        [352.1875, 360.3629, 342.0451,  ..., 269.5426, 252.9999, 236.4549],
        [350.6424, 357.4240, 338.0000,  ..., 267.9975, 252.9999, 238.0000],
        ...,
        [597.3622, 585.6775, 574.4026,  ..., 493.7609, 505.0429, 515.4084],
        [595.8171, 582.7385, 570.3576,  ..., 492.2158, 505.0429, 516.9536],
        [601.9975, 594.4943, 586.5379,  ..., 498.3961, 505.0429, 510.7732]])

In [344]:
output_data

tensor([[ 0.1119,  0.0821,  0.0643,  ...,  0.3437,  0.3360,  0.3290],
        [ 0.1240,  0.0904,  0.0701,  ...,  0.3269,  0.3197,  0.3131],
        [ 0.1119,  0.0821,  0.0643,  ...,  0.3437,  0.3360,  0.3290],
        ...,
        [10.8134, 10.6301, 10.5865,  ..., 21.8142, 20.7956, 19.8604],
        [10.6221, 10.5147, 10.5277,  ..., 22.0530, 21.0220, 20.0754],
        [11.4288, 11.0081, 10.7881,  ..., 21.1053, 20.1236, 19.2221]])

In [345]:
class Net(nn.Module):
    def __init__(self, input_size, output_size):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, 512)
        self.fc4 = nn.Linear(512, 256)
        self.fc5 = nn.Linear(256, 128)
        self.fc6 = nn.Linear(128, output_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = torch.relu(self.fc4(x))
        x = torch.relu(self.fc5(x))
        x = self.fc6(x)
        return x

In [346]:
input_size = 11
output_size = 66
model = Net(input_size=input_size, output_size=output_size)

In [347]:
criterion = nn.MSELoss()  # Mean Squared Error loss for regression
optimizer = optim.Adam(model.parameters(), lr=0.00001)  # Adam optimizer

In [348]:
num_epochs = 15
batch_size = 1
num_batches = len(input_data) // batch_size

for epoch in range(num_epochs):
    for batch in range(num_batches):
        start = batch * batch_size
        end = start + batch_size

        inputs = input_data[start:end]
        targets = output_data[start:end]

        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)

        # Compute loss
        loss = criterion(outputs, targets)

        # Backpropagation and optimization
        loss.backward()
        optimizer.step()

    # Print the loss for this epoch
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.6f}")

Epoch 1/15, Loss: 1.227836
Epoch 2/15, Loss: 1.090207
Epoch 3/15, Loss: 0.696722
Epoch 4/15, Loss: 0.454796
Epoch 5/15, Loss: 0.273489
Epoch 6/15, Loss: 0.237316
Epoch 7/15, Loss: 0.241296
Epoch 8/15, Loss: 0.252911
Epoch 9/15, Loss: 0.257222
Epoch 10/15, Loss: 0.254695
Epoch 11/15, Loss: 0.261782
Epoch 12/15, Loss: 0.263810
Epoch 13/15, Loss: 0.251714
Epoch 14/15, Loss: 0.279464
Epoch 15/15, Loss: 0.279204


In [349]:
predictions = model(df_to_tensor(temp_train))

In [350]:
pd.DataFrame(predictions.detach().numpy(), columns=spectrum_test.columns).head()

Unnamed: 0,0.000005,0.000005.1,0.000005.2,0.000005.3,0.000005.4,0.000005.5,0.000005.6,0.000005.7,0.000005.8,0.000005.9,...,0.000008,0.000008.1,0.000008.2,0.000008.3,0.000008.4,0.000008.5,0.000008.6,0.000008.7,0.000008.8,0.000008.9
0,2.790601,3.734772,3.551773,3.518033,3.052246,3.713205,3.674874,3.340815,3.95395,3.78427,...,7.757231,7.541261,7.631447,8.034568,7.89924,8.103604,7.596833,7.521339,6.925596,7.045845
1,2.660755,3.67539,3.459227,3.438999,2.937411,3.658012,3.575735,3.212343,3.859975,3.660512,...,7.546351,7.315151,7.375307,7.875371,7.722021,7.908733,7.400218,7.344554,6.72481,6.910841
2,2.790601,3.734772,3.551773,3.518033,3.052246,3.713205,3.674874,3.340815,3.95395,3.78427,...,7.757231,7.541261,7.631447,8.034568,7.89924,8.103604,7.596833,7.521339,6.925596,7.045845
3,2.341,3.51014,3.161989,3.232858,2.6323,3.531012,3.296155,2.866755,3.551401,3.293714,...,6.968811,6.701468,6.693804,7.43604,7.243237,7.32011,6.845436,6.842308,6.210735,6.525501
4,2.718299,3.626016,3.467902,3.460244,3.00514,3.585514,3.559272,3.253601,3.837587,3.687777,...,7.566105,7.376872,7.49767,7.827976,7.699324,7.931044,7.439087,7.325128,6.785043,6.879838


In [351]:
spectrum_test_scaled.head()

Unnamed: 0,0.000005,0.000005.1,0.000005.2,0.000005.3,0.000005.4,0.000005.5,0.000005.6,0.000005.7,0.000005.8,0.000005.9,...,0.000008,0.000008.1,0.000008.2,0.000008.3,0.000008.4,0.000008.5,0.000008.6,0.000008.7,0.000008.8,0.000008.9
0,2.373538,2.602481,2.775814,2.935397,3.074856,3.220465,3.363583,3.505908,3.647634,3.79009,...,10.06902,10.146292,10.208057,10.243674,10.246583,10.192531,10.036903,9.737349,9.326093,8.947765
1,2.027413,2.15003,2.250785,2.362463,2.467873,2.581041,2.6946,2.809533,2.925638,3.043688,...,8.794491,8.871859,8.934365,8.973221,8.983096,8.94286,8.813263,8.556884,8.201778,7.875021
2,2.47059,2.726463,2.921211,3.09597,3.246416,3.402639,3.555543,3.706972,3.857192,4.007671,...,10.460966,10.538142,10.59954,10.633985,10.634567,10.576088,10.412293,10.099323,9.670691,9.276407
3,1.950271,2.046041,2.131383,2.233903,2.333012,2.440313,2.548641,2.658761,2.770347,2.884035,...,8.542764,8.620148,8.682729,8.722118,8.733251,8.695631,8.57107,8.323129,7.979041,7.662406
4,2.523524,2.771409,2.957469,3.126776,3.273742,3.426844,3.576945,3.725913,3.87401,4.022686,...,10.45613,10.533535,10.595041,10.62953,10.630143,10.571716,10.408014,10.095198,9.666764,9.272662
