In [2]:
import sys
import os
sys.path.append(os.path.abspath(".."))

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dltime.base.layers import ConvBlock
import numpy as np
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
import torch
from dltime.models.FCN import FCN
from tqdm import tqdm
%matplotlib inline

In [None]:
data_for_train = ['zwy', 'zwy2', 'zwy3', 'zwy4', 'zwy5', 'j11', 'j11_2', 'j11_md', 'j11_527', 'yqcc', 'yqcc2', 'syf', 'syf2', 'sky', 'sky2', 'sky3', 'zyq', 'zyq2']
param_dict = {}
for data_name in tqdm(data_for_train):
    param_dict[data_name] = []
    named_params = torch.load(f"./outputs/{data_name}_FCN.pth")
    for n, p in named_params.items():
        if 'conv1d.weight' in n:
            param_dict[data_name].append(p.detach().clone().cpu().numpy())

In [9]:
class FCN(nn.Module):
    '''FCN'''
    def __init__(self, c_in, c_out, layers=[128, 256, 128], kss=[7, 5, 3], clf=True):
        super(FCN, self).__init__()
        self.clf = clf  # 是否作为分类器

        self.convblock1 = ConvBlock(c_in, layers[0], ks=kss[0])
        self.convblock2 = ConvBlock(layers[0], layers[1], ks=kss[1])
        self.convblock3 = ConvBlock(layers[1], layers[2], ks=kss[2])
        self.gap = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(layers[-1], c_out)

    def forward(self, x):
        x = self.convblock1(x)
        x1 = self.convblock2(x)
        x2 = self.convblock3(x1)
        x3 = self.gap(x2).squeeze(-1)
        x3 = self.fc(x3)
        return F.softmax(x2, dim=-1), torch.cat([x, x1, x2], dim=1)

In [4]:
from dltime.data.ts_datasets import Soil_Dataset
from utils import load_pkl
from data_process import handle_dataset_3dims
import torch.nn as nn
from torch.utils.data import DataLoader

In [5]:
import torch.fft

In [11]:
def infer_fn(valid_loader, model, device):
    model.eval()
    preds = []
    feature_maps = []
    for _, item in enumerate(valid_loader):
        for k, v in item.items():
            item[k] = v.to(device)

        labels = item['label']
        with torch.no_grad():
            y_preds, feature_map = model(item['input'])

        preds.append(y_preds.detach().cpu().numpy())
        feature_maps.append(feature_map.detach().cpu())

    
    predictions = np.concatenate(preds)
    feature_maps = torch.cat(feature_maps)
    feature_list = []
    for i in range(feature_maps.size(0)):
        map = feature_maps[i].unsqueeze(3)
        b, c, _ ,_= map.size()
        x_fft = torch.fft.fftshift(map)
        y_11 = nn.AdaptiveAvgPool2d(1)(x_fft[:,:,:]).view(b,c,1,1)
        y_22=nn.AdaptiveMaxPool2d(1)(abs(x_fft[:,:,:])).view(b,c,1,1)
        y_2=y_22/y_11
        feature_list.append(y_2)
    
    return predictions, feature_list

In [12]:
# data_for_train = ['zwy', 'zwy2', 'zwy3', 'zwy4', 'zwy5', 'j11', 'j11_2', 'j11_md', 'j11_527', 'yqcc', 'yqcc2', 'syf', 'syf2', 'sky', 'sky2', 'sky3', 'zyq', 'zyq2']
# data_for_train = ['sky', 'sky2', 'sky3']
data_for_train = ['zwy', 'zwy2', 'zwy3', 'zwy4', 'zwy5']
map_dict = {}
label_dict = {}
# train_data, test_data = [], []
for data_name in tqdm(data_for_train):
    model = FCN(c_in=5, c_out=3, layers=[64, 128, 64]).to('cuda')
    model.load_state_dict(torch.load(f'.\outputs\{data_name}_FCN.pth'))
    train_data = load_pkl(f'./pickle_data/{data_name}_train_64.pkl')
    test_data = load_pkl(f'./pickle_data/{data_name}_test_64.pkl')
    total_data = train_data + test_data
    total_x, total_label = handle_dataset_3dims(train_data, mode="all")
    total_x = np.swapaxes(total_x, 2, 1)
    total_dataset = Soil_Dataset(total_x, total_label, normalize=None, channel_first=True)
    total_dataloader = DataLoader(total_dataset, batch_size=16, shuffle=False, drop_last=False)

    label_dict[data_name] = total_label[:] 
    pred, maps = infer_fn(total_dataloader, model, 'cuda')
    map_dict[data_name] = maps[:]

  0%|          | 0/5 [00:12<?, ?it/s]


ValueError: too many values to unpack (expected 2)

In [18]:
map_dict['zwy'].shape

TypeError: 'tuple' object is not callable