In [1]:
import torch
from config import get_config, get_weights_file_path
from train import get_model, get_ds, get_normalize_dataset
from dataset import causal_mask, LSAMDataset
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader

In [2]:
device = torch.device(
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )
print(f'Using device {device}')
config = get_config()

Using device mps


In [3]:
train_dataloader, val_dataloader, tokenizer_src = get_ds(config)

In [4]:
model = get_model(config, tokenizer_src.get_vocab_size()).to(device)

In [7]:
model_filename = get_weights_file_path(config, f'499')

In [8]:
state = torch.load(model_filename, map_location=torch.device('mps'))
model.load_state_dict(state['model_state_dict'])

<All keys matched successfully>

In [20]:
def greedy_decode(model, source, source_mask, max_len, device):
    decoder_sos = torch.full((120, ), 1.0).to(device)
    decoder_eos = torch.full((120, ), 1.0).to(device)
    
    encoder_output = model.encode(source, source_mask)
    decoder_input = decoder_sos.unsqueeze(0).unsqueeze(0).to(device) # (1, 1, 120)
    while True:
        print(decoder_input.size(1))
        if decoder_input.size(1) == max_len:
            break
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)
        # calculate the output of the decoder
        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
        output = model.project(out[:, -1]).unsqueeze(0).to(device)
        decoder_input = torch.cat([decoder_input, output], dim=1)

        if torch.all(output[0] == decoder_eos):
            break
    return decoder_input.squeeze(0)


def run_validation(model, validation_ds, max_len, device):
    model.eval()

    # batch size = 1
    with torch.no_grad():
        for batch in validation_ds:
            encoder_input = batch['encoder_input'].to(device)  # (b, seq_len)
            encoder_mask = batch['encoder_mask'].to(device)  # (b, 1, 1, seq_len)
            assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

            model_output = greedy_decode(model, encoder_input, encoder_mask, max_len, device)

            return model_output

In [14]:
def min_max_normalize(two_d_list, min_value, max_value):
    """Helper function"""
    normalized_2d_list = [
        [(item - min_value) / (max_value - min_value) if max_value - min_value else 0 for item in sublist]
        for sublist in two_d_list
    ]
    return normalized_2d_list

def get_dataset(config):
    """ get dataset from csv file and normlize it
     key: source coordinates value: min-max normalized list"""
    path_list, num_layers_list = config["path_list"], config["num_layers_list"]
    max_value = -np.infty
    min_value = np.infty
    dataset = {}

   
    for layer in range(8, 9):
        value_list = []
        # csv_file_path = path_list[0] + f'layer_{layer}.csv'
        csv_file_path = './layer_8.csv'
        df = pd.read_csv(csv_file_path)
        df_max = df.values.max()
        df_min = df.values.min()
        if df_max > max_value:
            max_value = df_max
        if df_min < min_value:
            min_value = df_min

        coordinates = df.columns.tolist()
        key = ' '.join(coordinates)
        for column in df:
            column_list = df[column].tolist()
            value_list.append(column_list)
        dataset[key] = value_list

    normalized_dataset = {key: min_max_normalize(value, min_value, max_value) for key, value in dataset.items()}

    return normalized_dataset

In [15]:
normal_dataset = get_dataset(config)

In [16]:
normal_dataset

{'152_65 152_66 152_71 152_72 152_63 152_67 152_69 153_69 152_58 152_61 152_57 152_60 153_57 153_60 152_49 152_51 152_54 152_45 153_64 153_66 154_43 152_43 154_42 152_41 154_40 154_38 154_36 154_35 212_5 213_6 212_2 209_2 211_2 213_3 205_2 202_2 217_9 218_14 218_15 219_13 184_8 197_2 199_4 201_2 201_3 214_7 215_9 216_9 217_14 219_14 220_18 220_21 220_25 220_27 218_17 218_18 218_21 218_24 218_28 218_30 220_34 218_47 218_48 218_51 219_40 219_43 219_47 219_48 219_51 217_44 217_45 217_49 217_53 218_33 219_39 219_52 218_54 218_58 218_62 218_63 219_57 217_54 217_57 217_60 219_59 218_66 218_69 218_72 218_75 216_72 216_75 217_66 217_69 218_79 218_83 216_78 216_81 218_85 218_87 218_90 217_89 217_90 218_84 216_85 216_89 216_90 218_93 217_93 216_93 217_121 218_98 218_99 217_96 217_99 217_103 217_105 217_108 217_111 216_96 208_140 211_137 212_134 215_127 207_140 209_137 210_136 212_131 212_133 213_131 215_121 215_125 215_126 216_116 216_120 207_139 208_137 210_134 210_135 212_130 213_127 213_129 2

In [17]:
val_ds = LSAMDataset(normal_dataset, tokenizer_src, config['seq_len'])
val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=False)

In [18]:
val_dataloader.dataset[0]['decoder_input'][1:10]

tensor([[0.9637, 0.9453, 0.9230,  ..., 0.0918, 0.0918, 0.0851],
        [0.9580, 0.9415, 0.9211,  ..., 0.0874, 0.0866, 0.0777],
        [0.9283, 0.9111, 0.8919,  ..., 0.0814, 0.0800, 0.0777],
        ...,
        [0.9444, 0.9278, 0.9063,  ..., 0.0837, 0.0814, 0.0814],
        [0.8026, 0.7871, 0.7684,  ..., 0.0792, 0.0814, 0.0837],
        [0.9627, 0.9392, 0.9240,  ..., 0.0925, 0.0829, 0.0866]])

In [21]:
a = run_validation(model, val_dataloader, 744, device)
# a_original = np.loadtxt('mar15.txt')

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277


In [22]:
def get_normalize_dataset(config):
    """ get dataset from csv file and normlize it
     key: source coordinates value: min-max normalized list"""
    path_list, num_layers_list = config["path_list"], config["num_layers_list"]
    max_value = -np.infty
    min_value = np.infty
    dataset = {}

    for path, num_layers in zip(path_list, num_layers_list):
        for layer in range(1, num_layers + 1):
            value_list = []
            csv_file_path = path + f'layer_{layer}.csv'
            df = pd.read_csv(csv_file_path)
            df_max = df.values.max()
            df_min = df.values.min()
            if df_max > max_value:
                max_value = df_max
            if df_min < min_value:
                min_value = df_min

            coordinates = df.columns.tolist()
            key = ' '.join(coordinates)
            for column in df:
                column_list = df[column].tolist()
                value_list.append(column_list)
            dataset[key] = value_list

    normalized_dataset = {key: min_max_normalize(value, min_value, max_value) for key, value in dataset.items()}

    return max_value, min_value

In [23]:
get_normalize_dataset(config)

(207.11447391789065, 66.93675850248557)

In [24]:
X_min, X_max = 66.93675850248557, 207.11447391789065

In [25]:
a_original = a * (X_max - X_min) + X_min

In [26]:
a_original = a_original.cpu().numpy()
a_original

array([[207.11447 , 207.11447 , 207.11447 , ..., 207.11447 , 207.11447 ,
        207.11447 ],
       [182.94672 , 181.09247 , 178.203   , ...,  91.15544 ,  96.66438 ,
         89.12458 ],
       [183.67932 , 175.38388 , 176.68503 , ...,  96.92519 ,  98.77238 ,
         98.604485],
       ...,
       [183.91963 , 184.29599 , 181.56506 , ..., 101.294525,  98.807526,
         97.207   ],
       [184.17252 , 184.56865 , 180.33002 , ..., 100.86955 ,  98.32285 ,
         98.495735],
       [184.69218 , 185.172   , 178.41165 , ..., 101.530106,  98.12594 ,
         99.67575 ]], dtype=float32)

In [27]:
import matplotlib.pyplot as plt
import os

In [28]:
csv_path = './layer_8/'
all_files = os.listdir(csv_path)
csv_files = [file for file in all_files if file.endswith('.csv')]
sorted_csv_files = sorted(csv_files, key=lambda x: int(x.split('_')[0]))

In [29]:
ground_truth = []
ground_truth_filename = []
for file in sorted_csv_files:
    ground_truth_filename.append(file)
    file_path = os.path.join(csv_path, file)
    df = pd.read_csv(file_path, header=None)[0].tolist()
    ground_truth.append(df[:120])

In [30]:
ground_truth = np.array(ground_truth)

In [31]:
ground_truth.shape

(743, 120)

In [33]:
a_original.shape
np.savetxt('table_totems_for_planters_layer_8.txt', a_original)

In [42]:
x = np.arange(1, 121)
for i in range(len(ground_truth)):
    squared_differences = (a_original[i+1] - ground_truth[i]) ** 2
    mse = np.mean(squared_differences)
    rounded_mse = round(mse, 2)
    plt.scatter(x, a_original[i+1], label='Predictions', marker='o')
    plt.scatter(x, ground_truth[i], label='Observations')
    plt.xlim(0, 125)
    plt.ylim(75, 200)
    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    new_xticks = np.linspace(0, 120, num=7)  # 选择几个点作为刻度点
    new_labels = np.linspace(0, 240, num=7)  # 对应的刻度标签从1到240
    plt.xlabel('Time(s)', fontsize=16)
    plt.ylabel('Surface temperature(°C)', fontsize=16)
    plt.xticks(ticks=new_xticks, labels=new_labels.astype(int), fontsize=16)
    plt.yticks(fontsize=16)
    plt.legend(fancybox=True, framealpha=1, shadow=True, borderpad=0.5, 
               fontsize=16, loc='upper right',  borderaxespad=0.2)
    name, ext = os.path.splitext(ground_truth_filename[i])
    output_filename = name + f"_{rounded_mse}" + ".png"
    plt.savefig(f'./planters_validate_image/{output_filename}')
    # plt.show()
    plt.clf()

<Figure size 640x480 with 0 Axes>

In [43]:
def sort_key(filename):
    num_part = filename.split('_')[-1].replace('.png', '')
    return float(num_part)

In [44]:
image_path = './planters_validate_image/'
filenames = []
for file in os.listdir(image_path):
    if file.endswith('.png'):
        filenames.append(file)

In [45]:
sorted_filenames = sorted(filenames, key=sort_key)
for filename in sorted_filenames:
    print(filename)

386_(152, 93)_8.08.png
387_(152, 96)_10.38.png
293_(197, 143)_12.01.png
379_(153, 134)_13.86.png
388_(152, 99)_15.33.png
572_(123, 146)_16.86.png
245_(219, 54)_17.13.png
244_(219, 36)_19.03.png
237_(219, 17)_19.04.png
23_(153, 64)_19.23.png
580_(122, 146)_19.25.png
638_(134, 29)_19.77.png
24_(153, 66)_21.22.png
568_(108, 147)_21.68.png
252_(216, 123)_21.8.png
584_(135, 139)_22.19.png
233_(182, 8)_22.65.png
576_(106, 147)_22.94.png
635_(140, 54)_23.73.png
641_(126, 31)_24.44.png
380_(153, 135)_24.81.png
243_(219, 33)_24.93.png
230_(215, 6)_26.2.png
570_(114, 147)_26.97.png
713_(120, 33)_27.78.png
246_(219, 60)_27.8.png
575_(104, 147)_28.16.png
569_(113, 147)_28.2.png
634_(140, 48)_28.44.png
356_(165, 144)_29.24.png
242_(219, 30)_29.79.png
238_(219, 20)_30.43.png
231_(167, 14)_30.68.png
232_(178, 11)_30.74.png
389_(152, 102)_30.92.png
571_(117, 147)_31.99.png
236_(216, 8)_32.02.png
234_(191, 5)_32.57.png
577_(123, 148)_32.93.png
228_(185, 8)_33.7.png
579_(108, 148)_34.39.png
633_(137, 33

In [39]:
def find_indexes_with_prefix(lst, prefix):
    indexes = [i for i, item in enumerate(lst) if item.startswith(prefix)]
    return indexes

In [40]:
index = find_indexes_with_prefix(ground_truth_filename, '394_')
index

[371]

In [None]:
index = find_indexes_with_prefix(ground_truth_filename, '467_')
index

In [None]:
index = find_indexes_with_prefix(ground_truth_filename, '518_')
index

In [None]:
plt.scatter(x, ground_truth[371], label='Position 1 Observation')
plt.scatter(x, ground_truth[471], label='Position 2 Observation')
plt.scatter(x, ground_truth[421], label='Position 3 Observation')
plt.xlim(0, 125)
plt.ylim(75, 190)
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
new_xticks = np.linspace(0, 120, num=7) 
new_labels = np.linspace(0, 240, num=7) 
plt.xlabel('Time(s)', fontsize=16)
plt.ylabel('Surface temperature(°C)', fontsize=16)
plt.xticks(ticks=new_xticks, labels=new_labels.astype(int), fontsize=16)
plt.yticks(fontsize=16)
plt.legend(fancybox=True, framealpha=1, shadow=True, borderpad=0.5, 
            fontsize=16, loc='upper right',  borderaxespad=0.2)
plt.tight_layout()
plt.savefig('./paper_figures/3_ground_truth_profiles.png')
plt.show()

In [None]:
plt.scatter(x, a_original[372], label = 'Position 1 Prediction')
plt.scatter(x, ground_truth[371], label='Position 1 Observation')
# plt.scatter(x, ground_truth[421], label='Position 3')
# plt.scatter(x, ground_truth[471], label='Position 2')
plt.xlim(0, 125)
plt.ylim(75, 190)
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
new_xticks = np.linspace(0, 120, num=7) 
new_labels = np.linspace(0, 240, num=7) 
plt.xlabel('Time(s)', fontsize=16)
plt.ylabel('Surface temperature(°C)', fontsize=16)
plt.xticks(ticks=new_xticks, labels=new_labels.astype(int), fontsize=16)
plt.yticks(fontsize=16)
plt.legend(fancybox=True, framealpha=1, shadow=True, borderpad=0.5, 
            fontsize=16, loc='upper right',  borderaxespad=0.2)
plt.tight_layout()
# plt.savefig('./paper_figures/position1_compare.png')
plt.show()

In [None]:
plt.scatter(x, a_original[422], label = 'Position 2 Prediction')
plt.scatter(x, ground_truth[421], label='Position 2 Observation')
plt.xlim(0, 125)
plt.ylim(75, 190)
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
new_xticks = np.linspace(0, 120, num=7) 
new_labels = np.linspace(0, 240, num=7) 
plt.xlabel('Time(s)', fontsize=16)
plt.ylabel('Surface temperature(°C)', fontsize=16)
plt.xticks(ticks=new_xticks, labels=new_labels.astype(int), fontsize=16)
plt.yticks(fontsize=16)
plt.legend(fancybox=True, framealpha=1, shadow=True, borderpad=0.5, 
            fontsize=16, loc='upper right',  borderaxespad=0.2)
plt.savefig('./paper_figures/point2_profiles.png')
plt.show()

In [None]:
plt.scatter(x, a_original[472], label = 'Position 2 Prediction')
plt.scatter(x, ground_truth[471], label='Position 2 Observation')
plt.xlim(0, 125)
plt.ylim(75, 190)
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
new_xticks = np.linspace(0, 120, num=7) 
new_labels = np.linspace(0, 240, num=7) 
plt.xlabel('Time(s)', fontsize=16)
plt.ylabel('Surface temperature(°C)', fontsize=16)
plt.xticks(ticks=new_xticks, labels=new_labels.astype(int), fontsize=16)
plt.yticks(fontsize=16)
plt.legend(fancybox=True, framealpha=1, shadow=True, borderpad=0.5, 
            fontsize=16, loc='upper right',  borderaxespad=0.2)
plt.savefig('./paper_figures/point2_profiles.png')
plt.show()

In [None]:
plt.scatter(x, a_original[372], label='Position 1 Prediction')
plt.scatter(x, a_original[472], label='Position 2 Prediction')
plt.scatter(x, a_original[422], label='Position 3 Prediction')
plt.xlim(0, 125)
plt.ylim(75, 190)
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
new_xticks = np.linspace(0, 120, num=7) 
new_labels = np.linspace(0, 240, num=7) 
plt.xlabel('Time(s)', fontsize=16)
plt.ylabel('Surface temperature(°C)', fontsize=16)
plt.xticks(ticks=new_xticks, labels=new_labels.astype(int), fontsize=16)
plt.yticks(fontsize=16)
plt.legend(fancybox=True, framealpha=1, shadow=True, borderpad=0.5, 
            fontsize=16, loc='upper right',  borderaxespad=0.2)
plt.tight_layout()
plt.savefig('./paper_figures/prediction_compare.png')
plt.show()

In [None]:
x = np.arange(1, 121)
all_mse = 0
all_rmse = 0
for i in range(len(ground_truth)):
    squared_differences = (a_original[i+1] - ground_truth[i]) ** 2
    mse = np.mean(squared_differences)
    rounded_mse = round(mse, 2)
    all_mse += rounded_mse/120.0
    rmse = np.sqrt(rounded_mse)
    all_rmse += rmse/120.0

In [None]:
average_mse = all_mse/len(ground_truth)
average_rmse = all_rmse/len(ground_truth)

In [None]:
print(average_mse)
print(average_rmse)