# Problem

考虑某一股票的价格P，每过一段时间我们能获得一对数(t, Pt)，分别为当前的时间戳t和当前该股票的价格Pt，需要计算该股票在[t-W, t]时间段内的平均价格，
其中W是事先给定的窗口大小，单位和t相同。

请编写程序来实现这个计算过程。接口参见以下Interface节。

算法需要满足以下要求：

1. 相邻数据点之间的间隔不固定，可能有较大变化，在计算的时候应考虑数据间隔变化极端的情况下的鲁棒性;
2. 计算的平均价格应能较好反映股票价格在窗口时间内的“平均水平”
3. 假设在相邻两次数据到达期间内，股票的真实价格都等于前一次数据的价格
4. 所谓平均水平指的是股票价格在窗口内的简单移动平均（SMA）而不是指数移动平均（EMA）
5. 可用的内存大小是有限的，为O(num_bin)规模，当窗口内到达的数据量超过可用的存储空间时，应做一定的取舍，既不突破内存限制，又尽可能保证结果的精确性（只是简单地把尾部数据删除是不行的）
6. 实现以下MovingAverage接口
7. 尽量使用性能较优的算法和实现方法

编程语言：不限

# Note
====

num_bin: 使用的内存大小应为O(num_bin)级别
window：移动平均的窗口长度

timestamp是时间戳，浮点数，数据间隔不固定，可能一秒来几十个数据，也可能几小时来一个数据

# Code

## 程序代码

In [39]:
import struct
class MovingAverage:
    def __init__(self, num_bin: int, window: float):
        self.queue = []
        self.max_len = num_bin
        self.window = window

    def combine_floats(self, a, b):
        # 将a和b转换为二进制字符串
        a_bin = bin(struct.unpack('!I', struct.pack('!f', a))[0])[2:].zfill(32)
        b_bin = bin(struct.unpack('!I', struct.pack('!f', b))[0])[2:].zfill(32)
        # 将a和b的二进制字符串组合在一起
        combined_bin = a_bin + b_bin
        # 将组合后的二进制字符串转换为float64类型的值
        combined = struct.unpack('!d', struct.pack('!Q', int(combined_bin, 2)))[0]
        return combined

    def separate_floats(self, combined):
        # 将组合后的float64类型的值转换为二进制字符串
        combined_bin = bin(struct.unpack('!Q', struct.pack('!d', combined))[0])[2:].zfill(64)
        # 将二进制字符串分成两个32位的字符串
        a_bin = combined_bin[:32]
        b_bin = combined_bin[32:]
        # 将两个二进制字符串转换为float32类型的值
        a = struct.unpack('!f', struct.pack('!I', int(a_bin, 2)))[0]
        b = struct.unpack('!f', struct.pack('!I', int(b_bin, 2)))[0]
        return a, b

    def combine_ints(self, a, b):
        # 将a和b转换为二进制字符串
        a_bin = bin(a & 0xffffffff)[2:].zfill(32)
        b_bin = bin(b & 0xffffffff)[2:].zfill(32)
        # 将a和b的二进制字符串组合在一起
        combined_bin = a_bin + b_bin
        # 将组合后的二进制字符串转换为int64类型的值
        combined = int(combined_bin, 2)
        return combined

    def separate_ints(self, combined):
        # 将组合后的int64类型的值转换为二进制字符串
        combined_bin = bin(combined & 0xffffffffffffffff)[2:].zfill(64)
        # 将二进制字符串分成两个32位的字符串
        a_bin = combined_bin[:32]
        b_bin = combined_bin[32:]
        # 将两个二进制字符串转换为int32类型的值
        a = int(a_bin, 2)
        b = int(b_bin, 2)
        return a, b

    def check(self, timestamp: float):
        # 检查是否有过期数据，有则删除
        del_idx = -1
        M = len(self.queue)
        for k in range(2*M):
            i = k // 2
            j = k % 2
            time = self.separate_floats(self.queue[i][0])[j]
            if(time < timestamp - self.window):
                del_idx = i
            else:
                break
        self.queue = self.queue[del_idx:]
        print(self.queue)
        if(self.separate_floats(self.queue[0][0])[1] < timestamp - self.window):
            self.queue = self.queue[1:]
        else:
            M = len(self.queue)
            for i in range(M):
                if(i < M - 1):
                    time1 = self.separate_floats(self.queue[i][0])[1]
                    time2 = self.separate_floats(self.queue[i + 1][0])[0]
                    value1 = self.separate_floats(self.queue[i][1])[1]
                    value2 = self.separate_floats(self.queue[i + 1][1])[0]
                    weight1 = self.separate_ints(self.queue[i][2])[1]
                    weight2 = self.separate_ints(self.queue[i + 1][2])[0]
                    time1 = self.combine_floats(time1, time2)
                    value1 = self.combine_floats(value1, value2)
                    weight1 = self.combine_ints(weight1, weight2)
                    self.queue[i] = (time1, value1, weight1)
                else:
                    time1 = self.separate_floats(self.queue[i][0])[1]
                    value1 = self.separate_floats(self.queue[i][1])[1]
                    weight1 = self.separate_ints(self.queue[i][2])[1]
                    time1 = self.combine_floats(time1, 0.0)
                    value1 = self.combine_floats(value1, 0.0)
                    weight1 = self.combine_ints(weight1, 0)
                    self.queue[i] = (time1, value1, weight1)
        if(self.separate_ints(self.queue[-1][2])[0] == 0):
            self.queue.pop(-1)

    def merge(self):
        # 合并在时间上距离最近的两个数据
        if(len(self.queue) == self.max_len):
            if(self.separate_ints(self.queue[-1][2])[1] != 0):
                min_time_delta = 0x7fffffff
                M = len(self.queue)
                kdx = -1
                jdx = -1
                for k in range(2*M - 1):
                    i = k // 2
                    j = k % 2
                    time1 = self.separate_floats(self.queue[i][0])[j]
                    time2 = self.separate_floats(self.queue[(k+1)//2][0])[(j+1)%2]
                    time_delta = time2 - time1
                    if(time_delta < min_time_delta):
                        min_time_delta = time_delta
                        kdx = k
                        jdx = j
                weight1 = self.separate_ints(self.queue[kdx//2][2])[jdx]
                weight2 = self.separate_ints(self.queue[(kdx + 1)//2][2])[(jdx + 1)%2]
                fake_weight = weight1 + weight2
                time1 = self.separate_floats(self.queue[kdx//2][0])[jdx]
                time2 = self.separate_floats(self.queue[(kdx + 1)//2][0])[(jdx + 1)%2]
                fake_time = (time1*weight1 + time2*weight2)/fake_weight
                value1 = self.separate_floats(self.queue[kdx//2][1])[jdx]
                value2 = self.separate_floats(self.queue[(kdx + 1)//2][1])[(jdx + 1)%2]
                fake_value = (value1*weight1 + value2*weight2)/fake_weight

                if(jdx == 0):
                    fake_time = self.combine_floats(0.0, fake_time)
                    fake_value = self.combine_floats(0.0, fake_value)
                    fake_weight = self.combine_ints(0, fake_weight)
                    self.queue[kdx//2] = (fake_time, fake_value, fake_weight)
                else:
                    fake_time = self.combine_floats(self.separate_floats(self.queue[kdx//2][0])[0], fake_time)
                    fake_value = self.combine_floats(self.separate_floats(self.queue[kdx//2][1])[0], fake_value)
                    fake_weight = self.combine_ints(self.separate_ints(self.queue[kdx//2][2])[0], fake_weight)
                    self.queue[(kdx)//2] = (fake_time, fake_value, fake_weight)
                
                if(jdx == 0):
                    start_i = kdx//2
                else:
                    start_i = kdx//2 + 1
                M = len(self.queue)
                for i in range(start_i, M):
                    if(i < M - 1):
                        time1 = self.separate_floats(self.queue[i][0])[1]
                        time2 = self.separate_floats(self.queue[i + 1][0])[0]
                        value1 = self.separate_floats(self.queue[i][1])[1]
                        value2 = self.separate_floats(self.queue[i + 1][1])[0]
                        weight1 = self.separate_ints(self.queue[i][2])[1]
                        weight2 = self.separate_ints(self.queue[i + 1][2])[0]
                        time1 = self.combine_floats(time1, time2)
                        value1 = self.combine_floats(value1, value2)
                        weight1 = self.combine_ints(weight1, weight2)
                        self.queue[i] = (time1, value1, weight1)
                    else:
                        time1 = self.separate_floats(self.queue[i][0])[1]
                        value1 = self.separate_floats(self.queue[i][1])[1]
                        weight1 = self.separate_ints(self.queue[i][2])[1]
                        time1 = self.combine_floats(time1, 0.0)
                        value1 = self.combine_floats(value1, 0.0)
                        weight1 = self.combine_ints(weight1, 0)
                        self.queue[i] = (time1, value1, weight1)

    def Get(self) -> float:
        # 返回当前计算的平均值
        M = len(self.queue)
        sma = 0
        weight = 0
        for k in range(2*M):
            i = k//2
            j = k%2
            w = self.separate_ints(self.queue[i][2])[j]
            sma += self.separate_floats(self.queue[i][1])[j]*w
            weight += w
        return sma/weight
        
    def Update(self, timestamp: float, value: float):
        # 新数据到达，更新状态
        if(len(self.queue) > 0):
            self.check(timestamp)
        self.merge()

        if(len(self.queue) == 0):
            time = self.combine_floats(timestamp, 0.0)
            value = self.combine_floats(value, 0.0)
            weight = self.combine_ints(1, 0)
            self.queue.append((time, value, weight))
        else:
            tmp_time, tmp_value, tmp_weight = self.queue[-1]
            if(self.separate_ints(tmp_weight)[1] == 0):
                tmp_time = self.separate_floats(tmp_time)[0]
                tmp_value = self.separate_floats(tmp_value)[0]
                tmp_weight = self.separate_ints(tmp_weight)[0]
                time = self.combine_floats(tmp_time, timestamp)
                value = self.combine_floats(tmp_value, value)
                weight = self.combine_ints(tmp_weight, 1)
                self.queue[-1] = (time, value, weight)
            else:
                time = self.combine_floats(timestamp, 0.0)
                value = self.combine_floats(value, 0.0)
                weight = self.combine_ints(1, 0)
                self.queue.append((time, value, weight))


## 模拟数据实验

In [40]:
import numpy as np

np.random.seed(1024)

In [41]:
data_size = 1000
times = np.random.chisquare(1, data_size)
times = np.cumsum(times)
values = np.random.normal(10, 1, data_size)


W = 5
num_bin = 3

In [42]:
ma_robot = MovingAverage(num_bin, W)

correct_num = 0
mse = 0
for i in range(data_size):
    sma = 0
    for j in range(i, -1, -1):
        if(times[i] - times[j] > W):
            sma = np.mean(values[j + 1:i + 1])
            break
        if(j == 0):
            sma = np.mean(values[:i + 1])
    
    ma_robot.Update(times[i], values[i])
    cal_sma = ma_robot.Get()
    mse += (sma - cal_sma)**2
    print("sma: %-20s cal_sma: %-20s"%(sma, cal_sma), end=" ")
    if(abs(sma - cal_sma) < 1e-2):
        correct_num += 1
        print("correct!")
    else:
        print("wrong!")
mse /= data_size
print("correct rate = %-20s MSE = %-20s"%(correct_num/data_size, mse))

sma: 9.774049408634541    cal_sma: 9.774049758911133    correct!
[(0.003090161830186844, 465056.5, 4294967296)]
sma: 10.246851985717374   cal_sma: 10.71965503692627    wrong!
[(0.10022586584091187, 901594.5, 4294967296)]
sma: 9.751819715346498    cal_sma: 8.761754989624023    wrong!
[(3.565073013305664, 230916.75, 4294967296)]
sma: 9.673540423703816    cal_sma: 9.438702583312988    wrong!
[(3.922719955444336, 377147.25, 4294967296)]
sma: 9.77213905288551     cal_sma: 10.166533470153809   wrong!
[(29.974777221679688, 611599.5, 4294967296)]
sma: 9.969770199876853    cal_sma: 10.957925796508789   wrong!
[(93.04083251953125, 1026517.0, 4294967296)]


IndexError: list index out of range