In [1]:
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
import copy
import math
import numpy as np
import math
from data_prep import SystemIdentDataset, ControllerDataset, SystemIdentDatasetNormed
import pickle

def evaluate(model, loss_function, val_loader):
    c_error = 0.0
    cos_sim = torch.nn.CosineSimilarity(dim=1)
    running_loss = 0.0
    batch_count = 0
    for _, example in enumerate(tqdm(val_loader), 0):
        inputs,label = example

        with torch.no_grad():
            outputs = model(inputs)
        
        loss = loss_function(outputs, label)
        
        c_error += torch.sum(cos_sim(outputs,label)).item()/val_dloader.batch_size
        running_loss += loss.item()
        batch_count += 1
    
    c_error /= batch_count
    
    return running_loss, c_error

def train(model, num_epochs, loss_function, optimizer, train_loader, val_loader):
    best_loss = []
    best_cerror = []
    val_loss, c_error = evaluate(model, loss_function, val_loader)
    print(f"Initial validation loss: {val_loss}, cosine error: {c_error}")
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}:")
        for _, example in enumerate(tqdm(train_loader), 0):
            inputs,label = example
            
            outputs = model(inputs)

            loss = loss_function(outputs, label)

            optimizer.zero_grad()

            loss.backward()

            optimizer.step()
        val_loss, c_error = evaluate(model, loss_function, val_loader)
        print(f"validation loss: {val_loss}, cosine error: {c_error}")

        if epoch < 5:
            best_loss.append((copy.deepcopy(model), val_loss, c_error))
            best_cerror.append((copy.deepcopy(model), val_loss,c_error))
        else:
            for i, entry in enumerate(best_loss):
                m,l,c = entry
                if val_loss < l:
                    best_loss[i] = (copy.deepcopy(model), val_loss, c_error)
                    break
            for i, entry in enumerate(best_cerror):
                m,l,c = entry
                if c_error > c:
                    best_cerror[i] = (copy.deepcopy(model), val_loss, c_error)
                    break
    return best_loss + best_cerror


In [2]:
train_dataset = SystemIdentDatasetNormed(num_examples=1000000)
val_dataset = SystemIdentDatasetNormed(num_examples=100000)
train_dloader = DataLoader(train_dataset, batch_size=256,shuffle=True)
val_dloader = DataLoader(val_dataset, batch_size=256, shuffle=True)



In [3]:


# Define the model
model = torch.nn.Sequential(
    torch.nn.Linear(4, 32, dtype=torch.float64),
    torch.nn.Tanh(),
    torch.nn.Linear(32, 64, dtype=torch.float64),
    torch.nn.Tanh(),
    torch.nn.Linear(64, 64, dtype=torch.float64),
    torch.nn.Tanh(),
    torch.nn.Linear(64, 64, dtype=torch.float64),
    torch.nn.Tanh(),
    torch.nn.Linear(64, 32, dtype=torch.float64),
    torch.nn.Tanh(),
    torch.nn.Linear(32, 4, dtype=torch.float64),
)

# Define the loss function
#loss_fn = torch.nn.MSELoss()
csim = torch.nn.CosineSimilarity(dim=1)
loss_fn = lambda x,y: torch.sum(-1*csim(x,y))
# Define the optimizer
optimizer = torch.optim.Adam(model.parameters())
best_models = train(model=model, num_epochs=20, loss_function=loss_fn, optimizer=optimizer, train_loader=train_dloader, val_loader=val_dloader)
for i, model in enumerate(best_models):
    weights, loss, c_error = model
    print(f"Model {i}: loss: {loss}, cosine error: {c_error}")


100%|██████████| 391/391 [00:01<00:00, 208.04it/s]


Initial validation loss: -946.754009970265, cosine error: 0.00945845997812365
Epoch 1:


100%|██████████| 3907/3907 [00:25<00:00, 156.03it/s]
100%|██████████| 391/391 [00:02<00:00, 184.70it/s]


validation loss: -96698.08016414402, cosine error: 0.966053390386669
Epoch 2:


100%|██████████| 3907/3907 [00:25<00:00, 155.33it/s]
100%|██████████| 391/391 [00:02<00:00, 192.54it/s]


validation loss: -96090.62087548636, cosine error: 0.959984623516288
Epoch 3:


100%|██████████| 3907/3907 [00:26<00:00, 146.06it/s]
100%|██████████| 391/391 [00:01<00:00, 199.01it/s]


validation loss: -98243.74061657705, cosine error: 0.9814951708018008
Epoch 4:


100%|██████████| 3907/3907 [00:25<00:00, 156.14it/s]
100%|██████████| 391/391 [00:01<00:00, 204.42it/s]


validation loss: -97902.17448631114, cosine error: 0.9780827853891378
Epoch 5:


100%|██████████| 3907/3907 [00:25<00:00, 154.31it/s]
100%|██████████| 391/391 [00:01<00:00, 198.01it/s]


validation loss: -97907.72715864971, cosine error: 0.9781382588579934
Epoch 6:


100%|██████████| 3907/3907 [00:25<00:00, 154.11it/s]
100%|██████████| 391/391 [00:01<00:00, 204.20it/s]


validation loss: -98467.23127602843, cosine error: 0.9837279339436983
Epoch 7:


100%|██████████| 3907/3907 [00:25<00:00, 154.66it/s]
100%|██████████| 391/391 [00:02<00:00, 194.91it/s]


validation loss: -98477.64101322193, cosine error: 0.9838319314780004
Epoch 8:


100%|██████████| 3907/3907 [00:25<00:00, 154.92it/s]
100%|██████████| 391/391 [00:02<00:00, 194.70it/s]


validation loss: -99043.73479797032, cosine error: 0.9894874400372674
Epoch 9:


100%|██████████| 3907/3907 [00:25<00:00, 154.83it/s]
100%|██████████| 391/391 [00:01<00:00, 203.90it/s]


validation loss: -99208.08447709483, cosine error: 0.991129360584787
Epoch 10:


100%|██████████| 3907/3907 [00:25<00:00, 154.58it/s]
100%|██████████| 391/391 [00:01<00:00, 197.65it/s]


validation loss: -99165.92961373771, cosine error: 0.9907082162497773
Epoch 11:


100%|██████████| 3907/3907 [00:25<00:00, 153.83it/s]
100%|██████████| 391/391 [00:01<00:00, 204.12it/s]


validation loss: -98700.70440387349, cosine error: 0.9860604260297463
Epoch 12:


100%|██████████| 3907/3907 [00:25<00:00, 154.82it/s]
100%|██████████| 391/391 [00:01<00:00, 196.49it/s]


validation loss: -98760.79683630115, cosine error: 0.9866607740199523
Epoch 13:


100%|██████████| 3907/3907 [00:25<00:00, 154.86it/s]
100%|██████████| 391/391 [00:01<00:00, 204.28it/s]


validation loss: -98941.23593868376, cosine error: 0.9884634344897275
Epoch 14:


100%|██████████| 3907/3907 [00:25<00:00, 154.82it/s]
100%|██████████| 391/391 [00:02<00:00, 194.87it/s]


validation loss: -98395.13842768277, cosine error: 0.9830076968878154
Epoch 15:


100%|██████████| 3907/3907 [00:25<00:00, 155.26it/s]
100%|██████████| 391/391 [00:01<00:00, 196.26it/s]


validation loss: -99178.0149951576, cosine error: 0.9908289541555867
Epoch 16:


100%|██████████| 3907/3907 [00:25<00:00, 154.94it/s]
100%|██████████| 391/391 [00:01<00:00, 205.60it/s]


validation loss: -99588.793778839, cosine error: 0.9949328022981837
Epoch 17:


100%|██████████| 3907/3907 [00:25<00:00, 155.30it/s]
100%|██████████| 391/391 [00:01<00:00, 198.07it/s]


validation loss: -98745.75559887908, cosine error: 0.9865105059031238
Epoch 18:


100%|██████████| 3907/3907 [00:25<00:00, 155.87it/s]
100%|██████████| 391/391 [00:01<00:00, 206.32it/s]


validation loss: -98111.42514920284, cosine error: 0.9801732851382956
Epoch 19:


100%|██████████| 3907/3907 [00:25<00:00, 154.98it/s]
100%|██████████| 391/391 [00:01<00:00, 198.66it/s]


validation loss: -99005.1650634549, cosine error: 0.9891021126064469
Epoch 20:


100%|██████████| 3907/3907 [00:25<00:00, 156.08it/s]
100%|██████████| 391/391 [00:01<00:00, 203.32it/s]

validation loss: -98998.02799402179, cosine error: 0.9890308103622701
Model 0: loss: -99588.793778839, cosine error: 0.9949328022981837
Model 1: loss: -99178.0149951576, cosine error: 0.9908289541555867
Model 2: loss: -99005.1650634549, cosine error: 0.9891021126064469
Model 3: loss: -98998.02799402179, cosine error: 0.9890308103622701
Model 4: loss: -98111.42514920284, cosine error: 0.9801732851382956
Model 5: loss: -99588.793778839, cosine error: 0.9949328022981837
Model 6: loss: -99178.0149951576, cosine error: 0.9908289541555867
Model 7: loss: -99005.1650634549, cosine error: 0.9891021126064469
Model 8: loss: -98998.02799402179, cosine error: 0.9890308103622701
Model 9: loss: -98111.42514920284, cosine error: 0.9801732851382956





In [4]:
with open('./emulator_random_inputs3.pkl', 'wb') as f:
    pickle.dump(best_models[0][0], f)

In [36]:
# Now learn to drive emulated plant from state Zo to Zd in K steps where K is a hyperparameter

train_dataset = ControllerDataset(num_examples=1000000)
val_dataset = ControllerDataset(num_examples=100000)
train_dloader = DataLoader(train_dataset, batch_size=256,shuffle=True)
val_dloader = DataLoader(val_dataset, batch_size=256, shuffle=True)

class ControllerTrainedEnclosure(torch.nn.Module):
    def __init__(self, emulator_network, K):
        super(self.__class__, self).__init__()
        emulator_network.requires_grad=False 
        self.system_emulator = emulator_network
        self.K = K
        self.network = torch.nn.Sequential(
                torch.nn.Linear(4, 32, dtype=torch.float64),
                torch.nn.Tanh(),
                torch.nn.Linear(32, 64, dtype=torch.float64),
                torch.nn.Tanh(),
                torch.nn.Linear(64, 64, dtype=torch.float64),
                torch.nn.Tanh(),
                torch.nn.Linear(64, 32, dtype=torch.float64),
                torch.nn.Tanh(),
                torch.nn.Linear(32, 1, dtype=torch.float64),
                )

    def forward(self, x):
        for _ in range(self.K):
            xnew = torch.zeros_like(x)
            u = self.network(x)
            xnew[:,0] = u.squeeze()
            xnew[:,1:] = x[:,1:]

            dx = self.system_emulator(xnew)
            x = x + dx
            #torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1)
            
        
        return x



In [37]:

with open('./emulator_random_inputs2.pkl', 'rb') as f:
    emulator = pickle.load(f)
model = ControllerTrainedEnclosure(emulator, K=30)

# Define the loss function
loss_fn = torch.nn.MSELoss()

# Define the optimizer
optimizer = torch.optim.Adam(model.parameters())
best_models = train(model=model, num_epochs=5, loss_function=loss_fn, optimizer=optimizer, train_loader=train_dloader, val_loader=val_dloader)
for i, model in enumerate(best_models):
    weights, loss, c_error = model
    print(f"Model {i}: loss: {loss}, cosine error: {c_error}")

100%|██████████| 79/79 [00:03<00:00, 23.22it/s]


Initial validation loss: 149962.52260807224, cosine error: 0.0004733375169052871
Epoch 1:


100%|██████████| 3907/3907 [06:51<00:00,  9.50it/s]
100%|██████████| 79/79 [00:04<00:00, 18.15it/s]


validation loss: 359981.7840931482, cosine error: 0.02704468563412509
Epoch 2:


100%|██████████| 3907/3907 [06:38<00:00,  9.80it/s]
100%|██████████| 79/79 [00:04<00:00, 19.18it/s]


validation loss: 358830.5046627023, cosine error: 0.026850415248318663
Epoch 3:


100%|██████████| 3907/3907 [06:19<00:00, 10.28it/s]
100%|██████████| 79/79 [00:04<00:00, 19.17it/s]


validation loss: 358609.25255553244, cosine error: 0.02687181946608139
Epoch 4:


100%|██████████| 3907/3907 [06:20<00:00, 10.26it/s]
100%|██████████| 79/79 [00:04<00:00, 19.12it/s]


validation loss: 353410.2904779326, cosine error: 0.02665515366050358
Epoch 5:


100%|██████████| 3907/3907 [06:21<00:00, 10.24it/s]
100%|██████████| 79/79 [00:04<00:00, 18.97it/s]

validation loss: 337965.3851633391, cosine error: 0.026950246391290223
Model 0: loss: 359981.7840931482, cosine error: 0.02704468563412509
Model 1: loss: 358830.5046627023, cosine error: 0.026850415248318663
Model 2: loss: 358609.25255553244, cosine error: 0.02687181946608139
Model 3: loss: 353410.2904779326, cosine error: 0.02665515366050358
Model 4: loss: 337965.3851633391, cosine error: 0.026950246391290223
Model 5: loss: 359981.7840931482, cosine error: 0.02704468563412509
Model 6: loss: 358830.5046627023, cosine error: 0.026850415248318663
Model 7: loss: 358609.25255553244, cosine error: 0.02687181946608139
Model 8: loss: 353410.2904779326, cosine error: 0.02665515366050358
Model 9: loss: 337965.3851633391, cosine error: 0.026950246391290223





In [None]:
best_models[8][0].network

In [None]:


a = torch.zeros((1,4),dtype=torch.float64)
#a[0][2] = np.pi + np.pi/20
em = best_models[8][0].system_emulator
nn = best_models[8][0].network

em(a)


In [None]:
model

In [None]:

train_dataset = SystemIdentDatasetEuler(num_examples=2000000)
val_dataset = SystemIdentDatasetEuler(num_examples=100000)
train_dloader = DataLoader(train_dataset, batch_size=256, shuffle=True)
val_dloader = DataLoader(val_dataset, batch_size=256, shuffle=True)

model = torch.nn.Sequential(
    torch.nn.Linear(4, 32, dtype=torch.float64),
    torch.nn.Tanh(),
    torch.nn.Linear(32, 64, dtype=torch.float64),
    torch.nn.Tanh(),
    torch.nn.Linear(64, 64, dtype=torch.float64),
    torch.nn.Tanh(),
    torch.nn.Linear(64, 64, dtype=torch.float64),
    torch.nn.Tanh(),
    torch.nn.Linear(64, 32, dtype=torch.float64),
    torch.nn.Tanh(),
    torch.nn.Linear(32, 4, dtype=torch.float64),
)

# Define the loss function
loss_fn = torch.nn.MSELoss()

# Define the optimizer
optimizer = torch.optim.Adam(model.parameters())
best_models = train(model=model, num_epochs=25, loss_function=loss_fn, optimizer=optimizer, train_loader=train_dloader, val_loader=val_dloader)
for i, model in enumerate(best_models):
    weights, loss, c_error = model
    print(f"Model {i}: loss: {loss}, cosine error: {c_error}")