In [8]:
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

In [11]:
import os
import re

folder_path = '/content/train'  # 文件夹路径
files = os.listdir(folder_path)  # 获取文件夹中的所有文件

x = [] 
y = []

for file in files:
    if file.endswith('.txt'):  # 仅处理以 .txt 结尾的文件
        file_path = os.path.join(folder_path, file)  # 构建文件的完整路径
        # 读取文件中的数据，假设每个文件包含一个二维数组，可以使用 numpy 的 loadtxt 函数
        data = np.loadtxt(file_path)
        x.append(data)
 # 提取文件名中的最后一个数字
    match = re.search(r"-([\d]+)\.txt$", file)
    if match:
        # 将提取到的数字添加到数组中
        y.append(int(match.group(1)))

# 将 x 转换为 numpy 数组
x = np.array(x)
y = np.array(y)


In [12]:
y


array([4, 2, 4, 3, 2, 4, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 3, 3, 2,
       2, 2, 4, 2, 4, 2, 2, 2, 2, 2, 2, 1, 4, 3, 2, 2, 4, 2, 2, 2, 4, 2,
       2, 4, 4, 2, 2, 2, 2, 3, 2, 2, 2, 4, 3, 2, 4, 3, 2, 2, 3, 2, 2, 2,
       2, 3, 2, 2, 3, 3, 4, 2, 4, 2, 4, 3, 2, 2, 4, 2, 3, 2, 2, 4, 2, 4,
       2, 2, 3, 4, 4, 2, 4, 4, 2, 2, 2, 2, 4, 2, 1, 2, 2, 2, 2, 2, 3, 4,
       2, 2, 3, 2, 2, 3, 2, 4, 4, 2, 4, 2, 2, 4, 4, 2, 2, 2, 2, 3, 4, 2,
       4, 4, 2, 2, 2, 4, 4, 3, 2, 3, 3, 2, 2, 2, 4, 4, 3, 4, 2, 2, 2, 3,
       2, 2, 3, 4, 2, 2, 4, 3, 4, 2, 2, 2, 2, 4, 1, 2, 4, 4, 2, 2, 2, 2,
       3, 2, 2, 3, 1, 3, 2, 4, 2, 4, 2, 2, 2, 4, 2, 4, 4, 2, 2, 2, 2, 2,
       2, 3, 4, 4, 1, 2, 3, 4, 2, 2, 3, 2, 2, 2, 3, 4, 2, 2, 4, 2, 3, 4,
       2, 2, 2, 3, 3, 2, 2, 2, 2, 2, 2, 3, 3, 4, 4, 2, 3, 2, 2, 2, 3, 2,
       2, 2, 4, 2, 1, 2, 2, 4, 4, 2, 3, 4, 2, 3, 3, 4, 2, 2, 3, 2, 2, 3,
       2, 4, 2, 2, 4, 4, 2, 2, 4, 2, 3, 2, 4, 2, 3, 2, 3, 3, 2, 3, 2, 4,
       2, 3, 2, 2, 2, 3, 3, 2, 2, 2, 2, 2, 2, 4, 4,

In [13]:
x.shape

(818, 12501, 2)

In [14]:
x_train = x[:,:,1]
num_classes = 5
y_train = jax.nn.one_hot(y, num_classes)
y_train.shape

(818, 5)

In [15]:

array = x_train

# 提取每一行最大值前后共800个点的数值
window_size = 500

max_values = np.max(array, axis=1)  # 沿着第二个轴（12501维）计算最大值
max_indices = np.argmax(array, axis=1)  # 沿着第二个轴（12501维）计算最大值的索引
start_indices = np.maximum(0, max_indices - window_size // 2)
end_indices = np.minimum(12500, max_indices + window_size // 2)
end_indices = np.where(start_indices == 0, window_size, end_indices)  # 如果起始索引为0，设置为0
start_indices = np.where(end_indices == 12500, 12500 - window_size, start_indices)  # 如果结束索引为12500，设置为12500 - window_size
new_array = np.array([array[i, start_indices[i]:end_indices[i]+1] for i in range(array.shape[0])])


x_train = new_array
x_train.shape

(818, 501)

In [16]:
def make_network(layer_sizes):
    
    def init(key, scale=1e-2):
        params = []
        for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
            weight_key, bias_key = jax.random.split(key)
            weight = scale * jax.random.normal(weight_key, (n_in, n_out))
            bias = scale * jax.random.normal(bias_key, (n_out,))
            params.append((weight, bias))
        return params

    def relu(x):
        return jnp.maximum(0, x)

    def apply(params, x):
        for w, b in params[:-1]:
            x = relu(jnp.dot(x, w) + b)
        final_w, final_b = params[-1]
        return jnp.dot(x, final_w) + final_b

    return init, apply


In [22]:
layer_sizes = [501, 128, 64, 5] # IN: 801 numbers, OUT: 5 classes
init_fn, apply_fn = make_network(layer_sizes)
key = jax.random.PRNGKey(42)
params = init_fn(key)


In [23]:
from jax.flatten_util import ravel_pytree
ravel_pytree(params)[0].size

72837

In [24]:
def cross_entropy(params, x, y):
    logits = apply_fn(params, x)
    return jnp.sum(y * jax.nn.log_softmax(logits))

def cross_entropy_loss(params, x, y):
    return -jnp.mean(jax.vmap(cross_entropy, (None, 0, 0),0)(params, x, y))

In [25]:
# Define the update function
@jax.jit
def update(params, x, y, learning_rate):
    grads = jax.grad(cross_entropy_loss)(params, x, y)
    return jax.tree_map(lambda p, g: p - learning_rate * g, params, grads) 

# Define the accuracy function
def accuracy(params, x, y):
    predictions = jnp.argmax(apply_fn(params, x), axis=1)
    actual = jnp.argmax(y, axis=1)
    return jnp.mean(predictions == actual)

In [28]:
learning_rate = 0.01
num_epochs = 10
batch_size = 8

train_size = x_train.shape[0]
num_complete_batches, leftover = divmod(train_size, batch_size)
num_batches = num_complete_batches + bool(leftover)

for epoch in range(num_epochs):
    # Shuffle the training data
    key, subkey = jax.random.split(key)
    permutation = jax.random.permutation(subkey, train_size)
    x_train = x_train[permutation]
    y_train = y_train[permutation]

    for i in range(num_batches):
        # Get batch data
        batch_start = i * batch_size
        batch_end = (i + 1) * batch_size
        x_batch = x_train[batch_start:batch_end]
        y_batch = y_train[batch_start:batch_end]

        # Update parameters
        params = update(params, x_batch, y_batch, learning_rate)

    # Compute accuracy on training and test sets
    train_accuracy = accuracy(params, x_train, y_train)

    print(f"Epoch {epoch}: train accuracy = {train_accuracy:.3f}")

Epoch 0: train accuracy = 0.546
Epoch 1: train accuracy = 0.546
Epoch 2: train accuracy = 0.546
Epoch 3: train accuracy = 0.546
Epoch 4: train accuracy = 0.546
Epoch 5: train accuracy = 0.546
Epoch 6: train accuracy = 0.546
Epoch 7: train accuracy = 0.546
Epoch 8: train accuracy = 0.546
Epoch 9: train accuracy = 0.546


In [62]:

predictions = jnp.argmax(apply_fn(params, x_train), axis=1)
actual = jnp.argmax(y_train, axis=1)
idx = jnp.where(predictions != actual)[0]
len(idx)

371

In [63]:
for i in idx:
  print(files[idx[1]])

data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V 700baiμA4400mV -3.txt
data A V

In [60]:
predictions[idx]


Array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,