In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm
from torch.nn.utils import clip_grad_norm_

# Define the model
class PredictionModel(nn.Module):
    def __init__(self):
        super(PredictionModel, self).__init__()
        self.fc1 = nn.Linear(2, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 3)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Create an instance of the model and move it to the GPU
model = PredictionModel().cuda()

# Define the loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.1)

# Function to calculate x_cur and y_cur based on alpha, beta, gamma
def calculate_x_cur_y_cur(alpha, beta, gamma):
    L1 = 12.8
    L2 = 12.5
    L3 = 24
    R1 = 2.5
    R2 = 3
    t = 21.7

    alpha = np.radians(d2angle(alpha))
    beta = np.radians(d2angle(beta))
    gamma = np.radians(d2angle(gamma))

    x_cur = L1 * np.sin(alpha) + L2 * np.cos(alpha + beta) - L3 * np.sin(alpha + beta + gamma) - R1 * np.cos(alpha) - R2 * np.cos(alpha + beta + gamma)
    y_cur = t - L1 * np.cos(alpha) + L2 * np.sin(alpha + beta) + L3 * np.cos(alpha + beta + gamma) - R1 * np.sin(alpha) - R2 * np.sin(alpha + beta + gamma)

    return x_cur, y_cur

def angle2d(angle) :
    return int(4096 * (angle/360))
def d2angle(d) :
    return int(360 * (d/4096))

# Generate random alpha, beta, gamma values and calculate x_cur and y_cur
def generate_data():
    alpha = np.random.randint(angle2d(90), angle2d(181))
    beta = np.random.randint(angle2d(90), angle2d(271))
    gamma = np.random.randint(angle2d(90), angle2d(181))
    x_cur, y_cur = calculate_x_cur_y_cur(alpha, beta, gamma)
    return [x_cur, y_cur], [alpha, beta, gamma]

# Generate training data
def generate_training_data(num_samples):
    X_train = []
    y_train = []

    for _ in range(num_samples):
        while True :
            x, y = generate_data()
            if 0 < x[0] and x[0] < 40 and 0 < x[1] and x[1] < 40 :
                X_train.append(x)
                y_train.append(y)
                break

    X_train = torch.tensor(X_train, dtype=torch.float32).cuda()
    y_train = torch.tensor(y_train, dtype=torch.float32).cuda()

    return X_train, y_train

# Train the model
def train_model(model, num_epochs, batch_size):
    num_samples = num_epochs * batch_size
    X_train, y_train = generate_training_data(num_samples)
    for epoch in tqdm(range(num_epochs)):
        epoch_loss = 0.0

        for i in range(0, num_samples, batch_size):
            inputs = X_train[i:i + batch_size]
            labels = y_train[i:i + batch_size]
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            # Apply gradient clipping
            clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            epoch_loss += loss.item()

        # Update the learning rate
        scheduler.step()
        #print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}")

# Set random seed for reproducibility
torch.manual_seed(42)

def test_model(model, a, b) :
    L1 = 12.8
    L2 = 12.5
    L3 = 24
    R1 = 2.5
    R2 = 3
    t = 21.7
    data = torch.tensor([[a, b]], dtype=torch.float32).cuda()
    output = model(data)
    print(output[0][0].item(), output[0][1].item(), output[0][2].item())
    print(d2angle(output[0][0].item()), d2angle(output[0][1].item()), d2angle(output[0][2].item()))
    alpha = np.radians(d2angle(output[0][0].item()))
    beta = np.radians(d2angle(output[0][1].item()))
    gamma = np.radians(d2angle(output[0][2].item()))
    x_cur = L1 * np.sin(alpha) + L2 * np.cos(alpha + beta) - L3 * np.sin(alpha + beta + gamma) - R1 * np.cos(alpha) - R2 * np.cos(alpha + beta + gamma)
    y_cur = t - L1 * np.cos(alpha) + L2 * np.sin(alpha + beta) + L3 * np.cos(alpha + beta + gamma) - R1 * np.sin(alpha) - R2 * np.sin(alpha + beta + gamma)
    print("정답 : ", a, b)
    print("예측 : ", x_cur, y_cur)
#model.load_state_dict(torch.load("abg_prediction_model_5.pth"))
# Train the model for 1000 epochs with a batch size of 32
for i in range(30) :
    torch.manual_seed(42+i)
    train_model(model, num_epochs=1000, batch_size=64)
    torch.save(model.state_dict(), f"abg_prediction_model_{i}.pth")
    print("save :", f"abg_prediction_model_{i}.pth")
    test_model(model, 10, 15)
    test_model(model, 20, 20)
    test_model(model, 25, 10)

100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:28<00:00,  1.15it/s]


save : abg_prediction_model_0.pth
1768.700439453125 2646.433349609375 1440.6151123046875
155 232 126
정답 :  10 15
예측 :  10.590112349043165 15.172947186722242
1921.2523193359375 2774.8525390625 1579.2581787109375
168 243 138
정답 :  20 20
예측 :  19.690635715175734 20.17961780139782
1606.5263671875 2690.086669921875 1842.3939208984375
141 236 161
정답 :  25 10
예측 :  24.112559918815595 9.638735299493055


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:28<00:00,  1.15it/s]


save : abg_prediction_model_1.pth
1757.29833984375 2669.382080078125 1433.77587890625
154 234 126
정답 :  10 15
예측 :  11.070455321344324 15.090859908933037
1903.17041015625 2786.642333984375 1562.919677734375
167 244 137
정답 :  20 20
예측 :  19.492762174953818 19.974969364685982
1602.380615234375 2716.634765625 1837.291259765625
140 238 161
정답 :  25 10
예측 :  24.611684296448956 9.712410374328167


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:33<00:00,  1.14it/s]


save : abg_prediction_model_2.pth
1762.95361328125 2669.26220703125 1434.6883544921875
154 234 126
정답 :  10 15
예측 :  11.070455321344324 15.090859908933037
1902.580322265625 2787.70458984375 1564.6015625
167 245 137
정답 :  20 20
예측 :  19.728559281886255 20.224476834312764
1604.133544921875 2716.48095703125 1833.1097412109375
140 238 161
정답 :  25 10
예측 :  24.611684296448956 9.712410374328167


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:36<00:00,  1.14it/s]


save : abg_prediction_model_3.pth
1760.08154296875 2674.68505859375 1430.89990234375
154 235 125
정답 :  10 15
예측 :  10.966356749850183 15.28258562718862
1901.2940673828125 2793.63037109375 1561.838134765625
167 245 137
정답 :  20 20
예측 :  19.728559281886255 20.224476834312764
1603.0048828125 2720.36083984375 1835.114013671875
140 239 161
정답 :  25 10
예측 :  24.961774706776602 9.968001778421023


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:38<00:00,  1.14it/s]


save : abg_prediction_model_4.pth
1759.9908447265625 2673.78759765625 1432.0635986328125
154 235 125
정답 :  10 15
예측 :  10.966356749850183 15.28258562718862
1901.8486328125 2790.266357421875 1562.374755859375
167 245 137
정답 :  20 20
예측 :  19.728559281886255 20.224476834312764
1603.4754638671875 2719.701416015625 1835.657958984375
140 239 161
정답 :  25 10
예측 :  24.961774706776602 9.968001778421023


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:38<00:00,  1.14it/s]


save : abg_prediction_model_5.pth
1759.564697265625 2673.413818359375 1431.61376953125
154 234 125
정답 :  10 15
예측 :  10.672228284200688 15.2309023807408
1900.9754638671875 2789.515625 1561.9022216796875
167 245 137
정답 :  20 20
예측 :  19.728559281886255 20.224476834312764
1602.831787109375 2719.365478515625 1835.21533203125
140 239 161
정답 :  25 10
예측 :  24.961774706776602 9.968001778421023


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:38<00:00,  1.14it/s]


save : abg_prediction_model_6.pth
1759.548583984375 2673.42236328125 1431.607666015625
154 234 125
정답 :  10 15
예측 :  10.672228284200688 15.2309023807408
1900.9273681640625 2789.5009765625 1561.884765625
167 245 137
정답 :  20 20
예측 :  19.728559281886255 20.224476834312764
1602.7431640625 2719.285400390625 1835.168701171875
140 238 161
정답 :  25 10
예측 :  24.611684296448956 9.712410374328167


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:36<00:00,  1.14it/s]


save : abg_prediction_model_7.pth
1759.548583984375 2673.42236328125 1431.607666015625
154 234 125
정답 :  10 15
예측 :  10.672228284200688 15.2309023807408
1900.9273681640625 2789.5009765625 1561.884765625
167 245 137
정답 :  20 20
예측 :  19.728559281886255 20.224476834312764
1602.7431640625 2719.285400390625 1835.168701171875
140 238 161
정답 :  25 10
예측 :  24.611684296448956 9.712410374328167


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:34<00:00,  1.14it/s]


save : abg_prediction_model_8.pth
1759.548583984375 2673.42236328125 1431.607666015625
154 234 125
정답 :  10 15
예측 :  10.672228284200688 15.2309023807408
1900.9273681640625 2789.5009765625 1561.884765625
167 245 137
정답 :  20 20
예측 :  19.728559281886255 20.224476834312764
1602.7431640625 2719.285400390625 1835.168701171875
140 238 161
정답 :  25 10
예측 :  24.611684296448956 9.712410374328167


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:33<00:00,  1.15it/s]


save : abg_prediction_model_9.pth
1759.548583984375 2673.42236328125 1431.607666015625
154 234 125
정답 :  10 15
예측 :  10.672228284200688 15.2309023807408
1900.9273681640625 2789.5009765625 1561.884765625
167 245 137
정답 :  20 20
예측 :  19.728559281886255 20.224476834312764
1602.7431640625 2719.285400390625 1835.168701171875
140 238 161
정답 :  25 10
예측 :  24.611684296448956 9.712410374328167


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:38<00:00,  1.14it/s]


save : abg_prediction_model_10.pth
1759.548583984375 2673.42236328125 1431.607666015625
154 234 125
정답 :  10 15
예측 :  10.672228284200688 15.2309023807408
1900.9273681640625 2789.5009765625 1561.884765625
167 245 137
정답 :  20 20
예측 :  19.728559281886255 20.224476834312764
1602.7431640625 2719.285400390625 1835.168701171875
140 238 161
정답 :  25 10
예측 :  24.611684296448956 9.712410374328167


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:32<00:00,  1.15it/s]


save : abg_prediction_model_11.pth
1759.548583984375 2673.42236328125 1431.607666015625
154 234 125
정답 :  10 15
예측 :  10.672228284200688 15.2309023807408
1900.9273681640625 2789.5009765625 1561.884765625
167 245 137
정답 :  20 20
예측 :  19.728559281886255 20.224476834312764
1602.7431640625 2719.285400390625 1835.168701171875
140 238 161
정답 :  25 10
예측 :  24.611684296448956 9.712410374328167


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:33<00:00,  1.14it/s]


save : abg_prediction_model_12.pth
1759.548583984375 2673.42236328125 1431.607666015625
154 234 125
정답 :  10 15
예측 :  10.672228284200688 15.2309023807408
1900.9273681640625 2789.5009765625 1561.884765625
167 245 137
정답 :  20 20
예측 :  19.728559281886255 20.224476834312764
1602.7431640625 2719.285400390625 1835.168701171875
140 238 161
정답 :  25 10
예측 :  24.611684296448956 9.712410374328167


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:36<00:00,  1.14it/s]


save : abg_prediction_model_13.pth
1759.548583984375 2673.42236328125 1431.607666015625
154 234 125
정답 :  10 15
예측 :  10.672228284200688 15.2309023807408
1900.9273681640625 2789.5009765625 1561.884765625
167 245 137
정답 :  20 20
예측 :  19.728559281886255 20.224476834312764
1602.7431640625 2719.285400390625 1835.168701171875
140 238 161
정답 :  25 10
예측 :  24.611684296448956 9.712410374328167


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:35<00:00,  1.14it/s]


save : abg_prediction_model_14.pth
1759.548583984375 2673.42236328125 1431.607666015625
154 234 125
정답 :  10 15
예측 :  10.672228284200688 15.2309023807408
1900.9273681640625 2789.5009765625 1561.884765625
167 245 137
정답 :  20 20
예측 :  19.728559281886255 20.224476834312764
1602.7431640625 2719.285400390625 1835.168701171875
140 238 161
정답 :  25 10
예측 :  24.611684296448956 9.712410374328167


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:36<00:00,  1.14it/s]


save : abg_prediction_model_15.pth
1759.548583984375 2673.42236328125 1431.607666015625
154 234 125
정답 :  10 15
예측 :  10.672228284200688 15.2309023807408
1900.9273681640625 2789.5009765625 1561.884765625
167 245 137
정답 :  20 20
예측 :  19.728559281886255 20.224476834312764
1602.7431640625 2719.285400390625 1835.168701171875
140 238 161
정답 :  25 10
예측 :  24.611684296448956 9.712410374328167


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:35<00:00,  1.14it/s]


save : abg_prediction_model_16.pth
1759.548583984375 2673.42236328125 1431.607666015625
154 234 125
정답 :  10 15
예측 :  10.672228284200688 15.2309023807408
1900.9273681640625 2789.5009765625 1561.884765625
167 245 137
정답 :  20 20
예측 :  19.728559281886255 20.224476834312764
1602.7431640625 2719.285400390625 1835.168701171875
140 238 161
정답 :  25 10
예측 :  24.611684296448956 9.712410374328167


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:32<00:00,  1.15it/s]


save : abg_prediction_model_17.pth
1759.548583984375 2673.42236328125 1431.607666015625
154 234 125
정답 :  10 15
예측 :  10.672228284200688 15.2309023807408
1900.9273681640625 2789.5009765625 1561.884765625
167 245 137
정답 :  20 20
예측 :  19.728559281886255 20.224476834312764
1602.7431640625 2719.285400390625 1835.168701171875
140 238 161
정답 :  25 10
예측 :  24.611684296448956 9.712410374328167


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:33<00:00,  1.15it/s]


save : abg_prediction_model_18.pth
1759.548583984375 2673.42236328125 1431.607666015625
154 234 125
정답 :  10 15
예측 :  10.672228284200688 15.2309023807408
1900.9273681640625 2789.5009765625 1561.884765625
167 245 137
정답 :  20 20
예측 :  19.728559281886255 20.224476834312764
1602.7431640625 2719.285400390625 1835.168701171875
140 238 161
정답 :  25 10
예측 :  24.611684296448956 9.712410374328167


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:34<00:00,  1.14it/s]


save : abg_prediction_model_19.pth
1759.548583984375 2673.42236328125 1431.607666015625
154 234 125
정답 :  10 15
예측 :  10.672228284200688 15.2309023807408
1900.9273681640625 2789.5009765625 1561.884765625
167 245 137
정답 :  20 20
예측 :  19.728559281886255 20.224476834312764
1602.7431640625 2719.285400390625 1835.168701171875
140 238 161
정답 :  25 10
예측 :  24.611684296448956 9.712410374328167


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:36<00:00,  1.14it/s]


save : abg_prediction_model_20.pth
1759.548583984375 2673.42236328125 1431.607666015625
154 234 125
정답 :  10 15
예측 :  10.672228284200688 15.2309023807408
1900.9273681640625 2789.5009765625 1561.884765625
167 245 137
정답 :  20 20
예측 :  19.728559281886255 20.224476834312764
1602.7431640625 2719.285400390625 1835.168701171875
140 238 161
정답 :  25 10
예측 :  24.611684296448956 9.712410374328167


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:35<00:00,  1.14it/s]


save : abg_prediction_model_21.pth
1759.548583984375 2673.42236328125 1431.607666015625
154 234 125
정답 :  10 15
예측 :  10.672228284200688 15.2309023807408
1900.9273681640625 2789.5009765625 1561.884765625
167 245 137
정답 :  20 20
예측 :  19.728559281886255 20.224476834312764
1602.7431640625 2719.285400390625 1835.168701171875
140 238 161
정답 :  25 10
예측 :  24.611684296448956 9.712410374328167


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:33<00:00,  1.15it/s]


save : abg_prediction_model_22.pth
1759.548583984375 2673.42236328125 1431.607666015625
154 234 125
정답 :  10 15
예측 :  10.672228284200688 15.2309023807408
1900.9273681640625 2789.5009765625 1561.884765625
167 245 137
정답 :  20 20
예측 :  19.728559281886255 20.224476834312764
1602.7431640625 2719.285400390625 1835.168701171875
140 238 161
정답 :  25 10
예측 :  24.611684296448956 9.712410374328167


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:35<00:00,  1.14it/s]


save : abg_prediction_model_23.pth
1759.548583984375 2673.42236328125 1431.607666015625
154 234 125
정답 :  10 15
예측 :  10.672228284200688 15.2309023807408
1900.9273681640625 2789.5009765625 1561.884765625
167 245 137
정답 :  20 20
예측 :  19.728559281886255 20.224476834312764
1602.7431640625 2719.285400390625 1835.168701171875
140 238 161
정답 :  25 10
예측 :  24.611684296448956 9.712410374328167


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:35<00:00,  1.14it/s]


save : abg_prediction_model_24.pth
1759.548583984375 2673.42236328125 1431.607666015625
154 234 125
정답 :  10 15
예측 :  10.672228284200688 15.2309023807408
1900.9273681640625 2789.5009765625 1561.884765625
167 245 137
정답 :  20 20
예측 :  19.728559281886255 20.224476834312764
1602.7431640625 2719.285400390625 1835.168701171875
140 238 161
정답 :  25 10
예측 :  24.611684296448956 9.712410374328167


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:35<00:00,  1.14it/s]


save : abg_prediction_model_25.pth
1759.548583984375 2673.42236328125 1431.607666015625
154 234 125
정답 :  10 15
예측 :  10.672228284200688 15.2309023807408
1900.9273681640625 2789.5009765625 1561.884765625
167 245 137
정답 :  20 20
예측 :  19.728559281886255 20.224476834312764
1602.7431640625 2719.285400390625 1835.168701171875
140 238 161
정답 :  25 10
예측 :  24.611684296448956 9.712410374328167


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:37<00:00,  1.14it/s]


save : abg_prediction_model_26.pth
1759.548583984375 2673.42236328125 1431.607666015625
154 234 125
정답 :  10 15
예측 :  10.672228284200688 15.2309023807408
1900.9273681640625 2789.5009765625 1561.884765625
167 245 137
정답 :  20 20
예측 :  19.728559281886255 20.224476834312764
1602.7431640625 2719.285400390625 1835.168701171875
140 238 161
정답 :  25 10
예측 :  24.611684296448956 9.712410374328167


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:36<00:00,  1.14it/s]


save : abg_prediction_model_27.pth
1759.548583984375 2673.42236328125 1431.607666015625
154 234 125
정답 :  10 15
예측 :  10.672228284200688 15.2309023807408
1900.9273681640625 2789.5009765625 1561.884765625
167 245 137
정답 :  20 20
예측 :  19.728559281886255 20.224476834312764
1602.7431640625 2719.285400390625 1835.168701171875
140 238 161
정답 :  25 10
예측 :  24.611684296448956 9.712410374328167


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:37<00:00,  1.14it/s]


save : abg_prediction_model_28.pth
1759.548583984375 2673.42236328125 1431.607666015625
154 234 125
정답 :  10 15
예측 :  10.672228284200688 15.2309023807408
1900.9273681640625 2789.5009765625 1561.884765625
167 245 137
정답 :  20 20
예측 :  19.728559281886255 20.224476834312764
1602.7431640625 2719.285400390625 1835.168701171875
140 238 161
정답 :  25 10
예측 :  24.611684296448956 9.712410374328167


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:35<00:00,  1.14it/s]

save : abg_prediction_model_29.pth
1759.548583984375 2673.42236328125 1431.607666015625
154 234 125
정답 :  10 15
예측 :  10.672228284200688 15.2309023807408
1900.9273681640625 2789.5009765625 1561.884765625
167 245 137
정답 :  20 20
예측 :  19.728559281886255 20.224476834312764
1602.7431640625 2719.285400390625 1835.168701171875
140 238 161
정답 :  25 10
예측 :  24.611684296448956 9.712410374328167





In [13]:
def test_model(model, a, b) :
    L1 = 13
    L2 = 12.5
    L3 = 25
    R1 = 2.5
    R2 = 3.5
    t = 20.9
    data = torch.tensor([[a, b]], dtype=torch.float32).cuda()
    output = model(data)
    print(output[0][0].item(), output[0][1].item(), output[0][2].item())
    alpha = np.radians(output[0][0].item())
    beta = np.radians(output[0][1].item())
    gamma = np.radians(output[0][2].item())
    #alpha = output[0][0].item()
    #beta = output[0][1].item()
    #gamma = output[0][2].item()
    x_cur = L1 * np.sin(alpha) + L2 * np.cos(alpha + beta) + L3 * np.sin(alpha + beta + gamma)
    y_cur = t - L1 * np.cos(alpha) - L2 * np.sin(alpha + beta) + L3 * np.cos(alpha + beta + gamma) - R1 * np.sin(alpha) - R2 * np.cos(alpha + beta + gamma)
    print("정답 : ", a, b)
    print("예측 : ", x_cur, y_cur)

test_model(model, 10, 15)
test_model(model, 20, 20)
test_model(model, 25, 10)

96.75641632080078 184.5462188720703 243.52655029296875
정답 :  10 15
예측 :  21.90204885758662 11.453640865757466
92.53882598876953 148.8344268798828 234.44064331054688
정답 :  20 20
예측 :  29.50379842180012 20.588136392555022
82.17068481445312 197.24188232421875 222.6363067626953
정답 :  25 10
예측 :  30.2978236783575 12.030600072569525
