# Triplet Train

This notebook (`.ipynb`) is used to debug and run the Python training code step by step.

First, load the data.

In [1]:
```python
import logging

import numpy as np
import torch
import torch.optim as optim

import dataset
import tritrain
from dataset import read_all_of_huaweicup, read_npy_of_file, npy_file_name_converter
from dissimilarity import calculate_cs_dissimilarity_matrix
from geodesic import construct_knn_graph, compute_shortest_paths_dijkstra, compute_shortest_paths_dijkstra_cugraph

# Try to import cuDF and cuGraph (GPU graph routines); fall back to networkx if they are unavailable
try:
    import cudf
    import cugraph
    GPU_AVAILABLE = True
except ImportError:
    GPU_AVAILABLE = False
```
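The `geodesic` module imported above is not shown in this notebook. As a rough illustration of what its function names suggest (build a k-nearest-neighbour graph over a dissimilarity matrix, then run Dijkstra from every node), here is a minimal pure-NumPy sketch. `knn_graph`, `dijkstra` and `geodesic_matrix` are hypothetical names, not the module's actual API, and symmetrizing the kNN graph is an assumption.

```python
import heapq

import numpy as np


def knn_graph(dissim, k):
    """Symmetric k-nearest-neighbour graph as an adjacency list of (neighbour, weight) pairs."""
    n = dissim.shape[0]
    edges = {}
    for u in range(n):
        # the k nearest samples other than u itself
        nearest = [int(v) for v in np.argsort(dissim[u]) if v != u][:k]
        for v in nearest:
            w = float(dissim[u, v])
            # store both directions so the graph is undirected
            edges[(u, v)] = w
            edges[(v, u)] = w
    adj = [[] for _ in range(n)]
    for (u, v), w in edges.items():
        adj[u].append((v, w))
    return adj


def dijkstra(adj, source):
    """Shortest-path lengths from `source` to every node (inf if unreachable)."""
    dist = np.full(len(adj), np.inf)
    dist[source] = 0.0
    heap = [(0.0, source)]
    while heap:
        d, u = heapq.heappop(heap)
        if d > dist[u]:
            continue  # stale entry: a shorter path to u was already processed
        for v, w in adj[u]:
            if d + w < dist[v]:
                dist[v] = d + w
                heapq.heappush(heap, (d + w, v))
    return dist


def geodesic_matrix(dissim, k):
    """All-pairs geodesic distances on the kNN graph of `dissim`."""
    adj = knn_graph(dissim, k)
    return np.stack([dijkstra(adj, s) for s in range(len(adj))])


# Toy example: five points on a line with Euclidean dissimilarity
x = np.array([0.0, 1.0, 2.0, 3.0, 10.0])
D = np.abs(x[:, None] - x[None, :])
G = geodesic_matrix(D, k=2)
print(G[0, 4])  # 10.0
```

For points on a line the geodesic distances equal the direct ones; on curved data they are larger, because paths must follow the neighbourhood graph. This pure-Python loop is only for illustration; for 20,000 samples a compiled or GPU implementation (such as the cuGraph path imported above) is the practical choice.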

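## 2. Data Loading and Preprocessing
The data is loaded with the example code provided by the Huawei Cup organizers.

To prepare the dataset, we first define the relevant functions.

Note that each data file corresponds to a different sector. Only one file is loaded here; because each file (and each exported result) belongs to a different sector, three separate models are trained, one per file.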

In [2]:
```python
print("<<< Welcome to 2024 Wireless Algorithm Contest! >>>\n")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Round_idx = 1
File_idx = 1
bs_pos, tol_samp_num, anch_samp_num, port_num, ant_num, sc_num, anch_pos, H, d_geo = read_all_of_huaweicup(Round_idx, File_idx)

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger()
```


```text
<<< Welcome to 2024 Wireless Algorithm Contest! >>>

Processing Round 1 Case 1
Loading configuration data file
Loading input position file
Loading input CSI data of Case 1
Loading Channel CSI succeed
Loading GEO data succeed
[[   0.         2339.73608398 2482.65429688 ... 2290.90087891
  3095.56591797 1301.00170898]
 [2339.73608398    0.         1570.17565918 ... 1972.58605957
  1890.33178711 2245.20751953]
 [2482.65429688 1570.17565918    0.         ... 1390.14709473
  2452.04980469 2113.81030273]
 [2267.93457031  932.56329346 1780.73571777 ... 2248.97143555
  1986.77770996 2143.52075195]
 [3154.51660156 2009.21484375 2895.19995117 ... 2991.2109375
  1568.25439453 3070.37939453]]
(20000, 20000)
```


## 3. Network Training

Triplet Neural Networks

![Triplet Neural Networks](/home/xmax/Desktop/huaweicup/pic/img.png)

### 3.1 Triplet Selection

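The input is the D_geo matrix, which holds the geodesic distance between every pair of samples. The outputs are the embeddings f(i), f(j) and f(k), where i is an anchor (one of the 2000 anchor points), j is a negative example far from i, and k is a positive example close to i.
A single parameter Q controls how close a positive must be and how far a negative must be.
The paper uses Q between 0.02 and 0.2; here Q is left as a tunable hyperparameter.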
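`tritrain.train_triplet_network` is not shown, so the exact rule behind Q is unknown. One plausible reading of the description above, sketched in NumPy: for an anchor i, the positive k is drawn from the fraction Q of samples geodesically closest to i, and the negative j from the remaining, farther samples. `sample_triplet` is a hypothetical helper, and this quantile interpretation of Q is an assumption.

```python
import numpy as np


def sample_triplet(d_geo, anchor, q, rng):
    """Return (positive, negative) indices for `anchor`.

    Assumed meaning of Q: the positive is drawn from the fraction q of samples
    geodesically closest to the anchor, the negative from all farther samples.
    """
    n = d_geo.shape[0]
    # other samples sorted from nearest to farthest (stable sort keeps ties in index order)
    order = np.argsort(d_geo[anchor], kind="stable")
    order = order[order != anchor]
    # size of the "close" pool: at least 1, and leave at least 1 candidate negative
    n_close = min(max(1, round(q * (n - 1))), n - 2)
    positive = int(rng.choice(order[:n_close]))
    negative = int(rng.choice(order[n_close:]))
    return positive, negative


# Toy geodesic matrix: 10 points on a line
x = np.arange(10.0)
d_geo = np.abs(x[:, None] - x[None, :])
rng = np.random.default_rng(0)
pos, neg = sample_triplet(d_geo, anchor=5, q=0.2, rng=rng)
print(pos, neg, d_geo[5, pos], d_geo[5, neg])
```

Under this reading a smaller Q makes positives closer, so lowering Q during training (the log below shows Q falling from 0.30 to 0.034) would gradually tighten the notion of "close".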

### 3.2 Embedding Network


### 3.3 Triplet Loss
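`tritrain.TripletLoss`, used in the next cell, is not shown. A minimal NumPy sketch of a standard triplet margin loss in the notation of section 3.1 (anchor f(i), positive f(k), negative f(j)): L = max(0, ||f(i) - f(k)|| - ||f(i) - f(j)|| + margin), averaged over the batch. This is the same form as PyTorch's `torch.nn.TripletMarginLoss` with its defaults (p=2, mean reduction), apart from the small `eps` PyTorch adds inside the distance; whether `tritrain.TripletLoss` uses exactly this form rather than, say, squared distances is an assumption.

```python
import numpy as np


def triplet_loss(f_i, f_k, f_j, margin=1.0):
    """Mean triplet margin loss.

    f_i: anchor embeddings, f_k: positives (close to i), f_j: negatives (far from i),
    each of shape (batch, dim). Distances are plain (non-squared) Euclidean norms.
    """
    d_pos = np.linalg.norm(f_i - f_k, axis=1)
    d_neg = np.linalg.norm(f_i - f_j, axis=1)
    # hinge: zero once the negative is at least `margin` farther than the positive
    return float(np.mean(np.maximum(0.0, d_pos - d_neg + margin)))


f_i = np.array([[0.0, 0.0], [0.0, 0.0]])
f_k = np.array([[1.0, 0.0], [1.0, 0.0]])
f_j = np.array([[3.0, 0.0], [1.5, 0.0]])
print(triplet_loss(f_i, f_k, f_j, margin=1.0))  # 0.25
```

The first triplet already satisfies the margin (1 - 3 + 1 < 0) and contributes nothing; the second contributes 0.5, so the batch mean is 0.25.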


In [3]:
```python
# Create the embedding network
embed_net = tritrain.TripletNet().to(device)

# Initialize the triplet loss with the given margin
margin = 1.0
criterion = tritrain.TripletLoss(margin).to(device)

# Split the complex CSI into real and imaginary parts and merge them with the port axis:
# (tol_samp_num, port_num, ant_num, sc_num) -> (tol_samp_num, 2 * port_num, ant_num, sc_num)
H_combined = np.stack((H.real, H.imag), axis=2)
H_combined = H_combined.reshape(tol_samp_num, 2 * port_num, ant_num, sc_num)

optimizer = optim.Adam(embed_net.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

# Train the network
tritrain.train_triplet_network(embed_net, criterion, optimizer, scheduler, H_combined, d_geo, init_q=0.1)

# Save the trained weights
torch_net_file = dataset.tensor_file_name_converter(Round_idx, File_idx)
torch.save(embed_net.state_dict(), torch_net_file)
print("Model saved as " + torch_net_file)
```
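The stack-and-reshape step in the cell above interleaves real and imaginary parts per port. A small NumPy check on toy sizes, assuming `H` is laid out as (samples, port_num, ant_num, sc_num), which is what the reshape to 4 channels implies with 2 ports; the sizes below are made up for illustration.

```python
import numpy as np

# Toy sizes (assumed for illustration): 3 samples, 2 ports, 4 antennas, 5 subcarriers
n_samp, port_num, ant_num, sc_num = 3, 2, 4, 5
rng = np.random.default_rng(0)
shape = (n_samp, port_num, ant_num, sc_num)
H = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)

# Same operations as the training cell
H_combined = np.stack((H.real, H.imag), axis=2)  # (n_samp, port_num, 2, ant_num, sc_num)
H_combined = H_combined.reshape(n_samp, 2 * port_num, ant_num, sc_num)

# Channel c = 2 * p + r, where r = 0 is the real part and r = 1 the imaginary part of port p
for p in range(port_num):
    assert np.array_equal(H_combined[:, 2 * p], H[:, p].real)
    assert np.array_equal(H_combined[:, 2 * p + 1], H[:, p].imag)
print(H_combined.shape)  # (3, 4, 4, 5)
```

So the network's 4 input channels are ordered [port 0 real, port 0 imag, port 1 real, port 1 imag].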


Using CUDA for computation


Epoch 1/20:  16%|█▋        | 103/625 [00:07<00:32, 15.89it/s]

Epoch [1/20], Batch [100], Q: 0.3000, Loss: 18.1153


Epoch 1/20:  32%|███▏      | 203/625 [00:13<00:26, 15.98it/s]

Epoch [1/20], Batch [200], Q: 0.3000, Loss: 0.7521


Epoch 1/20:  48%|████▊     | 303/625 [00:19<00:20, 15.99it/s]

Epoch [1/20], Batch [300], Q: 0.3000, Loss: 0.6855


Epoch 1/20:  64%|██████▍   | 403/625 [00:25<00:13, 16.65it/s]

Epoch [1/20], Batch [400], Q: 0.3000, Loss: 0.6173


Epoch 1/20:  80%|████████  | 503/625 [00:32<00:08, 13.66it/s]

Epoch [1/20], Batch [500], Q: 0.3000, Loss: 0.6153


Epoch 1/20:  96%|█████████▋| 602/625 [00:40<00:01, 14.05it/s]

Epoch [1/20], Batch [600], Q: 0.3000, Loss: 0.5715


                                                             

Epoch [1/20] Completed, Learning Rate: 0.001000


Epoch 2/20:  20%|██        | 101/500 [00:08<00:30, 13.06it/s]

Epoch [2/20], Batch [100], Q: 0.2860, Loss: 0.5441


Epoch 2/20:  40%|████      | 202/500 [00:16<00:24, 12.22it/s]

Epoch [2/20], Batch [200], Q: 0.2860, Loss: 0.5313


Epoch 2/20:  60%|██████    | 301/500 [00:24<00:14, 13.44it/s]

Epoch [2/20], Batch [300], Q: 0.2860, Loss: 0.4968


Epoch 2/20:  80%|████████  | 401/500 [00:32<00:10,  9.70it/s]

Epoch [2/20], Batch [400], Q: 0.2860, Loss: 0.4546


                                                             

Epoch [2/20], Batch [500], Q: 0.2860, Loss: 0.4596
Epoch [2/20] Completed, Learning Rate: 0.001000


Epoch 3/20:  24%|██▍       | 102/417 [00:10<00:26, 11.83it/s]

Epoch [3/20], Batch [100], Q: 0.2720, Loss: 0.4189


Epoch 3/20:  48%|████▊     | 202/417 [00:19<00:21, 10.20it/s]

Epoch [3/20], Batch [200], Q: 0.2720, Loss: 0.3916


Epoch 3/20:  72%|███████▏  | 302/417 [00:28<00:09, 11.69it/s]

Epoch [3/20], Batch [300], Q: 0.2720, Loss: 0.3646


Epoch 3/20:  96%|█████████▌| 401/417 [00:37<00:01, 10.67it/s]

Epoch [3/20], Batch [400], Q: 0.2720, Loss: 0.3789


                                                             

Epoch [3/20] Completed, Learning Rate: 0.001000


Epoch 4/20:  28%|██▊       | 102/358 [00:10<00:24, 10.37it/s]

Epoch [4/20], Batch [100], Q: 0.2580, Loss: 0.3320


Epoch 4/20:  56%|█████▌    | 201/358 [00:20<00:19,  8.07it/s]

Epoch [4/20], Batch [200], Q: 0.2580, Loss: 0.3438


Epoch 4/20:  84%|████████▍ | 300/358 [00:30<00:05,  9.96it/s]

Epoch [4/20], Batch [300], Q: 0.2580, Loss: 0.3110


                                                             

Epoch [4/20] Completed, Learning Rate: 0.001000


Epoch 5/20:  32%|███▏      | 100/313 [00:11<00:22,  9.59it/s]

Epoch [5/20], Batch [100], Q: 0.2440, Loss: 0.2882


Epoch 5/20:  64%|██████▍   | 201/313 [00:24<00:12,  9.29it/s]

Epoch [5/20], Batch [200], Q: 0.2440, Loss: 0.2781


Epoch 5/20:  96%|█████████▋| 302/313 [00:36<00:01,  9.48it/s]

Epoch [5/20], Batch [300], Q: 0.2440, Loss: 0.2734


                                                             

Epoch [5/20] Completed, Learning Rate: 0.001000


Epoch 6/20:  36%|███▋      | 101/278 [00:11<00:18,  9.75it/s]

Epoch [6/20], Batch [100], Q: 0.2300, Loss: 0.2644


Epoch 6/20:  72%|███████▏  | 201/278 [00:23<00:09,  8.22it/s]

Epoch [6/20], Batch [200], Q: 0.2300, Loss: 0.2516


                                                             

Epoch [6/20] Completed, Learning Rate: 0.001000


Epoch 7/20:  40%|████      | 101/250 [00:13<00:17,  8.30it/s]

Epoch [7/20], Batch [100], Q: 0.2160, Loss: 0.2426


Epoch 7/20:  80%|████████  | 201/250 [00:27<00:12,  4.00it/s]

Epoch [7/20], Batch [200], Q: 0.2160, Loss: 0.2391


                                                             

Epoch [7/20] Completed, Learning Rate: 0.001000


Epoch 8/20:  44%|████▍     | 101/228 [00:20<00:25,  5.06it/s]

Epoch [8/20], Batch [100], Q: 0.2020, Loss: 0.2237


Epoch 8/20:  88%|████████▊ | 200/228 [00:42<00:05,  4.91it/s]

Epoch [8/20], Batch [200], Q: 0.2020, Loss: 0.2138


                                                             

Epoch [8/20] Completed, Learning Rate: 0.001000


Epoch 9/20:  48%|████▊     | 100/209 [00:21<00:22,  4.79it/s]

Epoch [9/20], Batch [100], Q: 0.1880, Loss: 0.1885


Epoch 9/20:  96%|█████████▌| 200/209 [00:44<00:01,  4.86it/s]

Epoch [9/20], Batch [200], Q: 0.1880, Loss: 0.1923


                                                             

Epoch [9/20] Completed, Learning Rate: 0.001000


Epoch 10/20:  52%|█████▏    | 100/193 [00:25<00:28,  3.27it/s]

Epoch [10/20], Batch [100], Q: 0.1740, Loss: 0.1730


                                                              

Epoch [10/20] Completed, Learning Rate: 0.000100


Epoch 11/20:  56%|█████▌    | 100/179 [00:29<00:21,  3.66it/s]

Epoch [11/20], Batch [100], Q: 0.1600, Loss: 0.1452


                                                              

Epoch [11/20] Completed, Learning Rate: 0.000100


Epoch 12/20:  60%|█████▉    | 100/167 [00:27<00:14,  4.56it/s]

Epoch [12/20], Batch [100], Q: 0.1460, Loss: 0.1251


                                                              

Epoch [12/20] Completed, Learning Rate: 0.000100


Epoch 13/20:  64%|██████▎   | 100/157 [00:31<00:17,  3.23it/s]

Epoch [13/20], Batch [100], Q: 0.1320, Loss: 0.1062


                                                              

Epoch [13/20] Completed, Learning Rate: 0.000100


Epoch 14/20:  64%|██████▎   | 100/157 [00:33<00:17,  3.33it/s]

Epoch [14/20], Batch [100], Q: 0.1180, Loss: 0.1003


                                                              

Epoch [14/20] Completed, Learning Rate: 0.000100


Epoch 15/20:  64%|██████▎   | 100/157 [00:30<00:17,  3.32it/s]

Epoch [15/20], Batch [100], Q: 0.1040, Loss: 0.0893


                                                              

Epoch [15/20] Completed, Learning Rate: 0.000100


Epoch 16/20:  64%|██████▎   | 100/157 [00:33<00:19,  2.93it/s]

Epoch [16/20], Batch [100], Q: 0.0900, Loss: 0.0864


                                                              

Epoch [16/20] Completed, Learning Rate: 0.000100


Epoch 17/20:  64%|██████▎   | 100/157 [00:31<00:17,  3.33it/s]

Epoch [17/20], Batch [100], Q: 0.0760, Loss: 0.0825


                                                              

Epoch [17/20] Completed, Learning Rate: 0.000100


Epoch 18/20:  64%|██████▎   | 100/157 [00:30<00:17,  3.34it/s]

Epoch [18/20], Batch [100], Q: 0.0620, Loss: 0.0783


                                                              

Epoch [18/20] Completed, Learning Rate: 0.000100


Epoch 19/20:  64%|██████▎   | 100/157 [00:32<00:18,  3.12it/s]

Epoch [19/20], Batch [100], Q: 0.0480, Loss: 0.0778


                                                              

Epoch [19/20] Completed, Learning Rate: 0.000100


Epoch 20/20:  64%|██████▎   | 100/157 [00:32<00:17,  3.35it/s]

Epoch [20/20], Batch [100], Q: 0.0340, Loss: 0.0758


                                                              

Epoch [20/20] Completed, Learning Rate: 0.000010
Model saved as ./CompetitionData1/Round1NET11721400377.7985215.pth
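The printed schedules can be checked against the settings in the training cell. Under PyTorch's StepLR rule, lr = base_lr * gamma ** (epoch // step_size), so `StepLR(step_size=5, gamma=0.5)` would already give 0.0005 after epoch 5, whereas the log stays at 0.001 through epoch 9, drops to 0.0001 after epoch 10 and to 0.00001 after epoch 20. That pattern matches step_size=10, gamma=0.1 (assuming the rate is printed after `scheduler.step()`). Likewise, the logged Q starts at 0.30 rather than the `init_q=0.1` passed in, and falls by 0.014 per epoch. The run shown was therefore probably made with different settings, or `tritrain` overrides them internally. The sketch below just evaluates both rules; `step_lr` and `linear_q` are hypothetical helpers, and the linear Q rule is fitted to the logged values, not taken from `tritrain`.

```python
def step_lr(base_lr, epoch, step_size, gamma):
    """LR after `epoch` scheduler steps under step decay (the closed form of torch's StepLR)."""
    return base_lr * gamma ** (epoch // step_size)


def linear_q(epoch, q0=0.30, dq=0.014):
    """Q for a 1-based `epoch`, assuming a linear decrease of dq per epoch (fitted to the log)."""
    return q0 - dq * (epoch - 1)


for epoch in (5, 9, 10, 20):
    print(epoch,
          f"StepLR(5, 0.5): {step_lr(1e-3, epoch, 5, 0.5):.6f}",
          f"StepLR(10, 0.1): {step_lr(1e-3, epoch, 10, 0.1):.6f}",
          f"Q: {linear_q(epoch):.4f}")
```

The StepLR(10, 0.1) column and the Q column reproduce the values in the log above; the StepLR(5, 0.5) column does not.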
